1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <math.h>
14 #include <stdbool.h>
15 
16 #include "config/aom_dsp_rtcd.h"
17 #include "config/av1_rtcd.h"
18 
19 #include "aom_dsp/aom_dsp_common.h"
20 #include "aom_dsp/blend.h"
21 #include "aom_mem/aom_mem.h"
22 #include "aom_ports/aom_timer.h"
23 #include "aom_ports/mem.h"
24 #include "aom_ports/system_state.h"
25 
26 #include "av1/common/cfl.h"
27 #include "av1/common/common.h"
28 #include "av1/common/common_data.h"
29 #include "av1/common/entropy.h"
30 #include "av1/common/entropymode.h"
31 #include "av1/common/idct.h"
32 #include "av1/common/mvref_common.h"
33 #include "av1/common/obmc.h"
34 #include "av1/common/onyxc_int.h"
35 #include "av1/common/pred_common.h"
36 #include "av1/common/quant_common.h"
37 #include "av1/common/reconinter.h"
38 #include "av1/common/reconintra.h"
39 #include "av1/common/scan.h"
40 #include "av1/common/seg_common.h"
41 #include "av1/common/txb_common.h"
42 #include "av1/common/warped_motion.h"
43 
44 #include "av1/encoder/aq_variance.h"
45 #include "av1/encoder/av1_quantize.h"
46 #include "av1/encoder/cost.h"
47 #include "av1/encoder/encodemb.h"
48 #include "av1/encoder/encodemv.h"
49 #include "av1/encoder/encoder.h"
50 #include "av1/encoder/encodetxb.h"
51 #include "av1/encoder/hybrid_fwd_txfm.h"
52 #include "av1/encoder/mcomp.h"
53 #include "av1/encoder/ml.h"
54 #include "av1/encoder/palette.h"
55 #include "av1/encoder/pustats.h"
56 #include "av1/encoder/random.h"
57 #include "av1/encoder/ratectrl.h"
58 #include "av1/encoder/rd.h"
59 #include "av1/encoder/rdopt.h"
60 #include "av1/encoder/reconinter_enc.h"
61 #include "av1/encoder/tokenize.h"
62 #include "av1/encoder/tx_prune_model_weights.h"
63 
// Set this macro to 1 to collect data about transform-size selection.
#define COLLECT_TX_SIZE_DATA 0

#if COLLECT_TX_SIZE_DATA
// Name of the file the collected transform-size data goes to (per the
// identifier; the code that writes it is not visible here). Only compiled in
// when COLLECT_TX_SIZE_DATA is enabled.
static const char av1_tx_size_data_output_file[] = "tx_size_data.txt";
#endif
70 
// Function-pointer type shared by the per-block model-RD estimators declared
// below (collected in model_rd_sb_fn). Results are written through the out_*,
// skip_* and plane_* pointer arguments; the planes covered are selected by
// plane_from / plane_to.
typedef void (*model_rd_for_sb_type)(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
// Function-pointer type for estimators that turn one plane's SSE into an
// estimated *rate and *dist (collected in model_rd_sse_fn).
// NOTE(review): num_samples is presumably the number of pixels the SSE was
// accumulated over; confirm at the implementations.
typedef void (*model_rd_from_sse_type)(const AV1_COMP *const cpi,
                                       const MACROBLOCK *const x,
                                       BLOCK_SIZE plane_bsize, int plane,
                                       int64_t sse, int num_samples, int *rate,
                                       int64_t *dist);

// Forward declarations of the estimator implementations. They are static and
// are referenced by the dispatch tables below, so their definitions must
// appear later in this translation unit.
static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                            MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
                            int plane_to, int mi_row, int mi_col,
                            int *out_rate_sum, int64_t *out_dist_sum,
                            int *skip_txfm_sb, int64_t *skip_sse_sb,
                            int *plane_rate, int64_t *plane_sse,
                            int64_t *plane_dist);
static void model_rd_for_sb_with_curvfit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_surffit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_dnn(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_fullrdy(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_from_sse(const AV1_COMP *const cpi,
                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
                              int plane, int64_t sse, int num_samples,
                              int *rate, int64_t *dist);
static void model_rd_with_dnn(const AV1_COMP *const cpi,
                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
                              int plane, int64_t sse, int num_samples,
                              int *rate, int64_t *dist);
static void model_rd_with_curvfit(const AV1_COMP *const cpi,
                                  const MACROBLOCK *const x,
                                  BLOCK_SIZE plane_bsize, int plane,
                                  int64_t sse, int num_samples, int *rate,
                                  int64_t *dist);
static void model_rd_with_surffit(const AV1_COMP *const cpi,
                                  const MACROBLOCK *const x,
                                  BLOCK_SIZE plane_bsize, int plane,
                                  int64_t sse, int num_samples, int *rate,
                                  int64_t *dist);
127 
// Selects a model-RD estimator. The enumerators index model_rd_sb_fn and
// model_rd_sse_fn (both sized MODELRD_TYPES), in the order listed here.
// MODELRD_SUFFIT selects the surface-fit model (the *_surffit functions).
// NOTE(review): UENUM1BYTE presumably gives the enum a one-byte storage type.
enum {
  MODELRD_LEGACY,
  MODELRD_CURVFIT,
  MODELRD_SUFFIT,
  MODELRD_DNN,
  MODELRD_FULLRDY,
  MODELRD_TYPES
} UENUM1BYTE(ModelRdType);
136 
137 static model_rd_for_sb_type model_rd_sb_fn[MODELRD_TYPES] = {
138   model_rd_for_sb, model_rd_for_sb_with_curvfit, model_rd_for_sb_with_surffit,
139   model_rd_for_sb_with_dnn, model_rd_for_sb_with_fullrdy
140 };
141 
142 static model_rd_from_sse_type model_rd_sse_fn[MODELRD_TYPES] = {
143   model_rd_from_sse, model_rd_with_curvfit, model_rd_with_surffit,
144   model_rd_with_dnn, NULL
145 };
146 
// Model-RD type used at each search stage; the value is a ModelRdType index
// into model_rd_sb_fn / model_rd_sse_fn:
// 0: Legacy model
// 1: Curve fit model
// 2: Surface fit model
// 3: DNN regression model
// 4: Full rd model
// All stages currently use the curve-fit model (1).
#define MODELRD_TYPE_INTERP_FILTER 1
#define MODELRD_TYPE_TX_SEARCH_PRUNE 1
#define MODELRD_TYPE_MASKED_COMPOUND 1
#define MODELRD_TYPE_INTERINTRA 1
#define MODELRD_TYPE_INTRA 1
#define MODELRD_TYPE_DIST_WTD_COMPOUND 1
#define MODELRD_TYPE_MOTION_MODE_RD 1
159 
// Number of (y, x) interpolation-filter pairs when each direction may use any
// of the SWITCHABLE_FILTERS filters.
#define DUAL_FILTER_SET_SIZE (SWITCHABLE_FILTERS * SWITCHABLE_FILTERS)
// Every dual-filter combination, packed as InterpFilters values. Within a row
// the low 16 bits (labelled y) are fixed, and the high 16 bits step through
// 0, 1, 2.
// NOTE(review): assumes InterpFilters stores the y filter in the low half and
// the x filter in the high half, as the row labels suggest. Also assumes
// SWITCHABLE_FILTERS == 3, so the 9 initializers fill the array.
static const InterpFilters filter_sets[DUAL_FILTER_SET_SIZE] = {
  0x00000000, 0x00010000, 0x00020000,  // y = 0
  0x00000001, 0x00010001, 0x00020001,  // y = 1
  0x00000002, 0x00010002, 0x00020002,  // y = 2
};
166 
// Linear-model coefficients, four per direction: entries 0-3 are for the
// vertical direction and entries 4-7 for the horizontal direction.
// NOTE(review): presumably the weights of an SVM that decides between ADST
// and FLIPADST during transform-type pruning; confirm at the use site.
static const double ADST_FLIP_SVM[8] = {
  /* vertical */
  -6.6623, -2.8062, -3.2531, 3.1671,
  /* horizontal */
  -7.7051, -3.2234, -3.6193, 3.4533
};
173 
// One candidate of the RD mode search: a prediction mode plus its reference
// frame pair. For single-reference inter and intra candidates, ref_frame[1]
// is NONE_FRAME (see av1_mode_order).
typedef struct {
  PREDICTION_MODE mode;
  MV_REFERENCE_FRAME ref_frame[2];
} MODE_DEFINITION;
178 
// Transform-search speed-up flags. The non-zero values are distinct bits, so
// they can be combined with '|'. Their meanings as suggested by the names
// (confirm at use sites): restrict the search to DCT / 1-D DCT types, skip
// trellis optimization, and measure distortion in the transform domain.
enum {
  FTXS_NONE = 0,
  FTXS_DCT_AND_1D_DCT_ONLY = 1 << 0,
  FTXS_DISABLE_TRELLIS_OPT = 1 << 1,
  FTXS_USE_TRANSFORM_DOMAIN = 1 << 2
} UENUM1BYTE(FAST_TX_SEARCH_MODE);
185 
// State shared across the per-transform-block RD computations of one block.
// NOTE(review): field roles below are inferred from names; confirm against
// the callbacks that consume this struct.
struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  // Working copies of the above/left entropy contexts.
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  // Accumulated rate/distortion statistics.
  RD_STATS rd_stats;
  // Running RD cost and the best RD cost to beat (presumably used together
  // with exit_early to abandon a hopeless search).
  int64_t this_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  int use_fast_coef_costing;
  // Combination of FAST_TX_SEARCH_MODE flags.
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};
200 
// NOTE(review): index 6 is the last single-reference NEARESTMV entry of
// av1_mode_order below (entries 0-6); the NEWMV entries occupy indices 7-13.
// Confirm which boundary the users of this macro expect.
#define LAST_NEW_MV_INDEX 6
// Candidate (mode, reference frame pair) list for the RD mode search,
// MAX_MODES entries long. Layout: single-reference inter modes (NEARESTMV,
// NEWMV, NEARMV, GLOBALMV over the seven inter references), then compound
// modes over the reference pairs, then the intra modes.
// NOTE(review): an entry's position presumably equals its THR_* mode index
// (compare intra_to_mode_idx / get_prediction_mode_idx); confirm before
// reordering.
static const MODE_DEFINITION av1_mode_order[MAX_MODES] = {
  // Single-reference inter modes.
  { NEARESTMV, { LAST_FRAME, NONE_FRAME } },
  { NEARESTMV, { LAST2_FRAME, NONE_FRAME } },
  { NEARESTMV, { LAST3_FRAME, NONE_FRAME } },
  { NEARESTMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEARESTMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEARESTMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEARESTMV, { GOLDEN_FRAME, NONE_FRAME } },

  { NEWMV, { LAST_FRAME, NONE_FRAME } },
  { NEWMV, { LAST2_FRAME, NONE_FRAME } },
  { NEWMV, { LAST3_FRAME, NONE_FRAME } },
  { NEWMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEWMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEWMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEWMV, { GOLDEN_FRAME, NONE_FRAME } },

  { NEARMV, { LAST_FRAME, NONE_FRAME } },
  { NEARMV, { LAST2_FRAME, NONE_FRAME } },
  { NEARMV, { LAST3_FRAME, NONE_FRAME } },
  { NEARMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEARMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEARMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEARMV, { GOLDEN_FRAME, NONE_FRAME } },

  // Note: unlike the groups above, GOLDEN comes before ALTREF here.
  { GLOBALMV, { LAST_FRAME, NONE_FRAME } },
  { GLOBALMV, { LAST2_FRAME, NONE_FRAME } },
  { GLOBALMV, { LAST3_FRAME, NONE_FRAME } },
  { GLOBALMV, { BWDREF_FRAME, NONE_FRAME } },
  { GLOBALMV, { ALTREF2_FRAME, NONE_FRAME } },
  { GLOBALMV, { GOLDEN_FRAME, NONE_FRAME } },
  { GLOBALMV, { ALTREF_FRAME, NONE_FRAME } },

  // TODO(zoeliu): May need to reconsider the order on the modes to check

  // Compound NEAREST_NEARESTMV over forward/backward reference pairs.
  { NEAREST_NEARESTMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, ALTREF2_FRAME } },

  // Compound NEAREST_NEARESTMV over same-direction reference pairs.
  { NEAREST_NEARESTMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAREST_NEARESTMV, { BWDREF_FRAME, ALTREF_FRAME } },

  // Remaining compound modes, seven per reference pair.
  { NEAR_NEARMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, LAST2_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, LAST3_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, GOLDEN_FRAME } },

  { NEAR_NEARMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { BWDREF_FRAME, ALTREF_FRAME } },

  // intra modes
  { DC_PRED, { INTRA_FRAME, NONE_FRAME } },
  { PAETH_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_V_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_H_PRED, { INTRA_FRAME, NONE_FRAME } },
  { H_PRED, { INTRA_FRAME, NONE_FRAME } },
  { V_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D135_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D203_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D157_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D67_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D113_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D45_PRED, { INTRA_FRAME, NONE_FRAME } },
};
398 
399 static const int16_t intra_to_mode_idx[INTRA_MODE_NUM] = {
400   THR_DC,         // DC_PRED,
401   THR_V_PRED,     // V_PRED,
402   THR_H_PRED,     // H_PRED,
403   THR_D45_PRED,   // D45_PRED,
404   THR_D135_PRED,  // D135_PRED,
405   THR_D113_PRED,  // D113_PRED,
406   THR_D157_PRED,  // D157_PRED,
407   THR_D203_PRED,  // D203_PRED,
408   THR_D67_PRED,   // D67_PRED,
409   THR_SMOOTH,     // SMOOTH_PRED,
410   THR_SMOOTH_V,   // SMOOTH_V_PRED,
411   THR_SMOOTH_H,   // SMOOTH_H_PRED,
412   THR_PAETH,      // PAETH_PRED,
413 };
414 
415 /* clang-format off */
416 static const int16_t single_inter_to_mode_idx[SINGLE_INTER_MODE_NUM]
417                                              [REF_FRAMES] = {
418   // NEARESTMV,
419   { -1, THR_NEARESTMV, THR_NEARESTL2, THR_NEARESTL3,
420     THR_NEARESTG, THR_NEARESTB, THR_NEARESTA2, THR_NEARESTA, },
421   // NEARMV,
422   { -1, THR_NEARMV, THR_NEARL2, THR_NEARL3,
423     THR_NEARG, THR_NEARB, THR_NEARA2, THR_NEARA, },
424   // GLOBALMV,
425   { -1, THR_GLOBALMV, THR_GLOBALL2, THR_GLOBALL3,
426     THR_GLOBALG, THR_GLOBALB, THR_GLOBALA2, THR_GLOBALA, },
427   // NEWMV,
428   { -1, THR_NEWMV, THR_NEWL2, THR_NEWL3,
429     THR_NEWG, THR_NEWB, THR_NEWA2, THR_NEWA, },
430 };
431 /* clang-format on */
432 
/* clang-format off */
// Lookup from a compound inter mode and an ordered reference pair to the
// THR_COMP_* mode index:
//   comp_inter_to_mode_idx[mode - COMP_INTER_MODE_START][ref_frame]
//                         [second_ref_frame]
// -1 marks pairs with no mode index. The only pairs with entries are:
// LAST with any later reference; LAST2/LAST3/GOLDEN with BWDREF, ALTREF2 or
// ALTREF; and BWDREF with ALTREF. Row/column 0 is never a valid reference
// here and is all -1.
static const int16_t comp_inter_to_mode_idx[COMP_INTER_MODE_NUM][REF_FRAMES]
                                     [REF_FRAMES] = {
  // NEAREST_NEARESTMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEAREST_NEARESTLL2, THR_COMP_NEAREST_NEARESTLL3,
      THR_COMP_NEAREST_NEARESTLG, THR_COMP_NEAREST_NEARESTLB,
      THR_COMP_NEAREST_NEARESTLA2, THR_COMP_NEAREST_NEARESTLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEARESTL2B,
      THR_COMP_NEAREST_NEARESTL2A2, THR_COMP_NEAREST_NEARESTL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEARESTL3B,
      THR_COMP_NEAREST_NEARESTL3A2, THR_COMP_NEAREST_NEARESTL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEARESTGB,
      THR_COMP_NEAREST_NEARESTGA2, THR_COMP_NEAREST_NEARESTGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEARESTBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAR_NEARMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEAR_NEARLL2, THR_COMP_NEAR_NEARLL3,
      THR_COMP_NEAR_NEARLG, THR_COMP_NEAR_NEARLB,
      THR_COMP_NEAR_NEARLA2, THR_COMP_NEAR_NEARLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEARL2B,
      THR_COMP_NEAR_NEARL2A2, THR_COMP_NEAR_NEARL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEARL3B,
      THR_COMP_NEAR_NEARL3A2, THR_COMP_NEAR_NEARL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEARGB,
      THR_COMP_NEAR_NEARGA2, THR_COMP_NEAR_NEARGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEARBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAREST_NEWMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEAREST_NEWLL2, THR_COMP_NEAREST_NEWLL3,
      THR_COMP_NEAREST_NEWLG, THR_COMP_NEAREST_NEWLB,
      THR_COMP_NEAREST_NEWLA2, THR_COMP_NEAREST_NEWLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEWL2B,
      THR_COMP_NEAREST_NEWL2A2, THR_COMP_NEAREST_NEWL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEWL3B,
      THR_COMP_NEAREST_NEWL3A2, THR_COMP_NEAREST_NEWL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEWGB,
      THR_COMP_NEAREST_NEWGA2, THR_COMP_NEAREST_NEWGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEAREST_NEWBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEARESTMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEW_NEARESTLL2, THR_COMP_NEW_NEARESTLL3,
      THR_COMP_NEW_NEARESTLG, THR_COMP_NEW_NEARESTLB,
      THR_COMP_NEW_NEARESTLA2, THR_COMP_NEW_NEARESTLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARESTL2B,
      THR_COMP_NEW_NEARESTL2A2, THR_COMP_NEW_NEARESTL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARESTL3B,
      THR_COMP_NEW_NEARESTL3A2, THR_COMP_NEW_NEARESTL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARESTGB,
      THR_COMP_NEW_NEARESTGA2, THR_COMP_NEW_NEARESTGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARESTBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAR_NEWMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEAR_NEWLL2, THR_COMP_NEAR_NEWLL3,
      THR_COMP_NEAR_NEWLG, THR_COMP_NEAR_NEWLB,
      THR_COMP_NEAR_NEWLA2, THR_COMP_NEAR_NEWLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEWL2B,
      THR_COMP_NEAR_NEWL2A2, THR_COMP_NEAR_NEWL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEWL3B,
      THR_COMP_NEAR_NEWL3A2, THR_COMP_NEAR_NEWL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEWGB,
      THR_COMP_NEAR_NEWGA2, THR_COMP_NEAR_NEWGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEAR_NEWBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEARMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEW_NEARLL2, THR_COMP_NEW_NEARLL3,
      THR_COMP_NEW_NEARLG, THR_COMP_NEW_NEARLB,
      THR_COMP_NEW_NEARLA2, THR_COMP_NEW_NEARLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARL2B,
      THR_COMP_NEW_NEARL2A2, THR_COMP_NEW_NEARL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARL3B,
      THR_COMP_NEW_NEARL3A2, THR_COMP_NEW_NEARL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARGB,
      THR_COMP_NEW_NEARGA2, THR_COMP_NEW_NEARGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEARBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // GLOBAL_GLOBALMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_GLOBAL_GLOBALLL2, THR_COMP_GLOBAL_GLOBALLL3,
      THR_COMP_GLOBAL_GLOBALLG, THR_COMP_GLOBAL_GLOBALLB,
      THR_COMP_GLOBAL_GLOBALLA2, THR_COMP_GLOBAL_GLOBALLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_GLOBAL_GLOBALL2B,
      THR_COMP_GLOBAL_GLOBALL2A2, THR_COMP_GLOBAL_GLOBALL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_GLOBAL_GLOBALL3B,
      THR_COMP_GLOBAL_GLOBALL3A2, THR_COMP_GLOBAL_GLOBALL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_GLOBAL_GLOBALGB,
      THR_COMP_GLOBAL_GLOBALGA2, THR_COMP_GLOBAL_GLOBALGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_GLOBAL_GLOBALBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEWMV,
  {
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1,
      THR_COMP_NEW_NEWLL2, THR_COMP_NEW_NEWLL3,
      THR_COMP_NEW_NEWLG, THR_COMP_NEW_NEWLB,
      THR_COMP_NEW_NEWLA2, THR_COMP_NEW_NEWLA, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEWL2B,
      THR_COMP_NEW_NEWL2A2, THR_COMP_NEW_NEWL2A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEWL3B,
      THR_COMP_NEW_NEWL3A2, THR_COMP_NEW_NEWL3A, },
    { -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEWGB,
      THR_COMP_NEW_NEWGA2, THR_COMP_NEW_NEWGA, },
    { -1, -1,
      -1, -1,
      -1, -1,
      -1, THR_COMP_NEW_NEWBA, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
    { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
};
/* clang-format on */
646 
// Translates (mode, ref_frame, second_ref_frame) into a THR_* mode index by
// consulting the intra, single-reference and compound lookup tables.
// Compound reference pairs without an entry yield -1 (the table value). Modes
// outside all three ranges trip an assert and return -1 when asserts are
// compiled out.
static int get_prediction_mode_idx(PREDICTION_MODE this_mode,
                                   MV_REFERENCE_FRAME ref_frame,
                                   MV_REFERENCE_FRAME second_ref_frame) {
  // Intra: no reference frames may be attached.
  if (this_mode < INTRA_MODE_END) {
    assert(ref_frame == INTRA_FRAME);
    assert(second_ref_frame == NONE_FRAME);
    const int intra_offset = this_mode - INTRA_MODE_START;
    return intra_to_mode_idx[intra_offset];
  }

  const int is_single_inter =
      this_mode >= SINGLE_INTER_MODE_START && this_mode < SINGLE_INTER_MODE_END;
  const int is_comp_inter =
      this_mode >= COMP_INTER_MODE_START && this_mode < COMP_INTER_MODE_END;

  if (is_single_inter) {
    assert((ref_frame > INTRA_FRAME) && (ref_frame <= ALTREF_FRAME));
    const int single_offset = this_mode - SINGLE_INTER_MODE_START;
    return single_inter_to_mode_idx[single_offset][ref_frame];
  }

  if (is_comp_inter) {
    assert((ref_frame > INTRA_FRAME) && (ref_frame <= ALTREF_FRAME));
    assert((second_ref_frame > INTRA_FRAME) &&
           (second_ref_frame <= ALTREF_FRAME));
    const int comp_offset = this_mode - COMP_INTER_MODE_START;
    return comp_inter_to_mode_idx[comp_offset][ref_frame][second_ref_frame];
  }

  // Unknown mode: not expected to happen.
  assert(0);
  return -1;
}
671 
672 static const PREDICTION_MODE intra_rd_search_mode_order[INTRA_MODES] = {
673   DC_PRED,       H_PRED,        V_PRED,    SMOOTH_PRED, PAETH_PRED,
674   SMOOTH_V_PRED, SMOOTH_H_PRED, D135_PRED, D203_PRED,   D157_PRED,
675   D67_PRED,      D113_PRED,     D45_PRED,
676 };
677 
678 static const UV_PREDICTION_MODE uv_rd_search_mode_order[UV_INTRA_MODES] = {
679   UV_DC_PRED,     UV_CFL_PRED,   UV_H_PRED,        UV_V_PRED,
680   UV_SMOOTH_PRED, UV_PAETH_PRED, UV_SMOOTH_V_PRED, UV_SMOOTH_H_PRED,
681   UV_D135_PRED,   UV_D203_PRED,  UV_D157_PRED,     UV_D67_PRED,
682   UV_D113_PRED,   UV_D45_PRED,
683 };
684 
// RD result recorded for one single-reference inter candidate.
// NOTE(review): 'valid' presumably flags whether 'rd' has been filled in for
// this entry; confirm at the writers.
typedef struct SingleInterModeState {
  int64_t rd;
  MV_REFERENCE_FRAME ref_frame;
  int valid;
} SingleInterModeState;
690 
// Bookkeeping carried through the mode search for one block: the best
// candidate found so far, plus caches that let later candidates reuse
// earlier results.
// NOTE(review): the group descriptions below are inferred from field names
// and array shapes; confirm against the search code before relying on them.
typedef struct InterModeSearchState {
  // Best candidate so far: RD cost, mode info, rates and flags.
  int64_t best_rd;
  MB_MODE_INFO best_mbmode;
  int best_rate_y;
  int best_rate_uv;
  int best_mode_skippable;
  int best_skip2;
  int best_mode_index;
  int skip_intra_modes;
  // Per-reference-frame statistics (one slot per REF_FRAMES entry).
  int num_available_refs;
  int64_t dist_refs[REF_FRAMES];
  int dist_order_refs[REF_FRAMES];
  // One threshold per mode, sized like av1_mode_order (MAX_MODES).
  int64_t mode_threshold[MAX_MODES];
  // Best intra result and directional-mode pruning state.
  PREDICTION_MODE best_intra_mode;
  int64_t best_intra_rd;
  int angle_stats_ready;
  uint8_t directional_mode_skip_mask[INTRA_MODES];
  unsigned int best_pred_sse;
  // Cached chroma intra search results, indexed by transform size.
  int rate_uv_intra[TX_SIZES_ALL];
  int rate_uv_tokenonly[TX_SIZES_ALL];
  int64_t dist_uvs[TX_SIZES_ALL];
  int skip_uvs[TX_SIZES_ALL];
  UV_PREDICTION_MODE mode_uv[TX_SIZES_ALL];
  PALETTE_MODE_INFO pmi_uv[TX_SIZES_ALL];
  int8_t uv_angle_delta[TX_SIZES_ALL];
  // Per-reference-mode RD values (REFERENCE_MODES entries).
  int64_t best_pred_rd[REFERENCE_MODES];
  int64_t best_pred_diff[REFERENCE_MODES];
  // Save a set of single_newmv for each checked ref_mv.
  int_mv single_newmv[MAX_REF_MV_SERCH][REF_FRAMES];
  int single_newmv_rate[MAX_REF_MV_SERCH][REF_FRAMES];
  int single_newmv_valid[MAX_REF_MV_SERCH][REF_FRAMES];
  int64_t modelled_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
  // The rd of simple translation in single inter modes
  int64_t simple_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];

  // Single search results by [directions][modes][reference frames]
  SingleInterModeState single_state[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
  int single_state_cnt[2][SINGLE_INTER_MODE_NUM];
  SingleInterModeState single_state_modelled[2][SINGLE_INTER_MODE_NUM]
                                            [FWD_REFS];
  int single_state_modelled_cnt[2][SINGLE_INTER_MODE_NUM];

  MV_REFERENCE_FRAME single_rd_order[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
} InterModeSearchState;
735 
// Classifies a block size for inter-mode RD modelling. Returns -1 for the
// excluded small sizes (4x4, 4x8, 8x4, 4x16, 16x4) and 1 for every other
// size. av1_inter_mode_data_fit skips sizes that map to -1.
static int inter_mode_data_block_idx(BLOCK_SIZE bsize) {
  switch (bsize) {
    case BLOCK_4X4:
    case BLOCK_4X8:
    case BLOCK_8X4:
    case BLOCK_4X16:
    case BLOCK_16X4: return -1;
    default: return 1;
  }
}
743 
// Resets every per-block-size inter-mode RD model in tile_data: marks it not
// ready and zeroes its sample count and accumulated sums. The derived means
// are left untouched.
void av1_inter_mode_data_init(TileDataEnc *tile_data) {
  InterModeRdModel *const models = tile_data->inter_mode_rd_models;
  for (int bsize = BLOCK_SIZES_ALL - 1; bsize >= 0; --bsize) {
    InterModeRdModel *const model = models + bsize;
    model->sse_ld_sum = 0;
    model->sse_sse_sum = 0;
    model->sse_sum = 0;
    model->ld_sum = 0;
    model->dist_sum = 0;
    model->num = 0;
    model->ready = 0;
  }
}
756 
// Estimates the residue rate and distortion of a bsize block whose prediction
// SSE is `sse`, using the per-block-size linear model in
// tile_data->inter_mode_rd_models (fields a, b and dist_mean; see
// av1_inter_mode_data_fit).
//
// Returns 1 and fills *est_residue_cost / *est_dist when the model for bsize
// is ready. Returns 0 and leaves both outputs untouched otherwise.
// *est_residue_cost is clamped to [0, INT_MAX / 2]. Whenever it ends up 0,
// *est_dist is set to sse.
static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  aom_clear_system_state();
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (!md->ready) return 0;

  // SSE already below the modelled distortion: no residue is worth coding.
  if (sse < md->dist_mean) {
    *est_residue_cost = 0;
    *est_dist = sse;
    return 1;
  }

  *est_dist = (int64_t)round(md->dist_mean);
  const double est_ld = md->a * sse + md->b;
  // Clamp estimated rate cost by INT_MAX / 2.
  // TODO(angiebird@google.com): find better solution than clamping.
  if (fabs(est_ld) < 1e-2) {
    *est_residue_cost = INT_MAX / 2;
  } else {
    const double est_residue_cost_dbl = (sse - md->dist_mean) / est_ld;
    // Clamp in the floating-point domain before converting. Converting a
    // double that is NaN or outside the target integer range is undefined
    // behavior (C11 6.3.1.4), so the original (int64_t)round(...) cast was
    // unsafe for such values. '!(x > 0)' also routes NaN to 0.
    if (!(est_residue_cost_dbl > 0)) {
      *est_residue_cost = 0;
    } else if (est_residue_cost_dbl >= INT_MAX / 2) {
      *est_residue_cost = INT_MAX / 2;
    } else {
      *est_residue_cost = (int)round(est_residue_cost_dbl);
    }
  }
  if (*est_residue_cost <= 0) {
    *est_residue_cost = 0;
    *est_dist = sse;
  }
  return 1;
}
791 
av1_inter_mode_data_fit(TileDataEnc * tile_data,int rdmult)792 void av1_inter_mode_data_fit(TileDataEnc *tile_data, int rdmult) {
793   aom_clear_system_state();
794   for (int bsize = 0; bsize < BLOCK_SIZES_ALL; ++bsize) {
795     const int block_idx = inter_mode_data_block_idx(bsize);
796     InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
797     if (block_idx == -1) continue;
798     if ((md->ready == 0 && md->num < 200) || (md->ready == 1 && md->num < 64)) {
799       continue;
800     } else {
801       if (md->ready == 0) {
802         md->dist_mean = md->dist_sum / md->num;
803         md->ld_mean = md->ld_sum / md->num;
804         md->sse_mean = md->sse_sum / md->num;
805         md->sse_sse_mean = md->sse_sse_sum / md->num;
806         md->sse_ld_mean = md->sse_ld_sum / md->num;
807       } else {
808         const double factor = 3;
809         md->dist_mean =
810             (md->dist_mean * factor + (md->dist_sum / md->num)) / (factor + 1);
811         md->ld_mean =
812             (md->ld_mean * factor + (md->ld_sum / md->num)) / (factor + 1);
813         md->sse_mean =
814             (md->sse_mean * factor + (md->sse_sum / md->num)) / (factor + 1);
815         md->sse_sse_mean =
816             (md->sse_sse_mean * factor + (md->sse_sse_sum / md->num)) /
817             (factor + 1);
818         md->sse_ld_mean =
819             (md->sse_ld_mean * factor + (md->sse_ld_sum / md->num)) /
820             (factor + 1);
821       }
822 
823       const double my = md->ld_mean;
824       const double mx = md->sse_mean;
825       const double dx = sqrt(md->sse_sse_mean);
826       const double dxy = md->sse_ld_mean;
827 
828       md->a = (dxy - mx * my) / (dx * dx - mx * mx);
829       md->b = my - md->a * mx;
830       md->ready = 1;
831 
832       md->num = 0;
833       md->dist_sum = 0;
834       md->ld_sum = 0;
835       md->sse_sum = 0;
836       md->sse_sse_sum = 0;
837       md->sse_ld_sum = 0;
838     }
839     (void)rdmult;
840   }
841 }
842 
inter_mode_data_push(TileDataEnc * tile_data,BLOCK_SIZE bsize,int64_t sse,int64_t dist,int residue_cost)843 static void inter_mode_data_push(TileDataEnc *tile_data, BLOCK_SIZE bsize,
844                                  int64_t sse, int64_t dist, int residue_cost) {
845   if (residue_cost == 0 || sse == dist) return;
846   const int block_idx = inter_mode_data_block_idx(bsize);
847   if (block_idx == -1) return;
848   InterModeRdModel *rd_model = &tile_data->inter_mode_rd_models[bsize];
849   if (rd_model->num < INTER_MODE_RD_DATA_OVERALL_SIZE) {
850     aom_clear_system_state();
851     const double ld = (sse - dist) * 1. / residue_cost;
852     ++rd_model->num;
853     rd_model->dist_sum += dist;
854     rd_model->ld_sum += ld;
855     rd_model->sse_sum += sse;
856     rd_model->sse_sse_sum += (double)sse * (double)sse;
857     rd_model->sse_ld_sum += sse * ld;
858   }
859 }
860 
inter_modes_info_push(InterModesInfo * inter_modes_info,int mode_rate,int64_t sse,int64_t rd,bool true_rd,uint8_t * blk_skip,RD_STATS * rd_cost,RD_STATS * rd_cost_y,RD_STATS * rd_cost_uv,const MB_MODE_INFO * mbmi)861 static void inter_modes_info_push(InterModesInfo *inter_modes_info,
862                                   int mode_rate, int64_t sse, int64_t rd,
863                                   bool true_rd, uint8_t *blk_skip,
864                                   RD_STATS *rd_cost, RD_STATS *rd_cost_y,
865                                   RD_STATS *rd_cost_uv,
866                                   const MB_MODE_INFO *mbmi) {
867   const int num = inter_modes_info->num;
868   assert(num < MAX_INTER_MODES);
869   inter_modes_info->mbmi_arr[num] = *mbmi;
870   inter_modes_info->mode_rate_arr[num] = mode_rate;
871   inter_modes_info->sse_arr[num] = sse;
872   inter_modes_info->est_rd_arr[num] = rd;
873   inter_modes_info->true_rd_arr[num] = true_rd;
874   if (blk_skip != NULL) {
875     memcpy(inter_modes_info->blk_skip_arr[num], blk_skip,
876            sizeof(blk_skip[0]) * MAX_MIB_SIZE * MAX_MIB_SIZE);
877   }
878   inter_modes_info->rd_cost_arr[num] = *rd_cost;
879   inter_modes_info->rd_cost_y_arr[num] = *rd_cost_y;
880   inter_modes_info->rd_cost_uv_arr[num] = *rd_cost_uv;
881   ++inter_modes_info->num;
882 }
883 
compare_rd_idx_pair(const void * a,const void * b)884 static int compare_rd_idx_pair(const void *a, const void *b) {
885   if (((RdIdxPair *)a)->rd == ((RdIdxPair *)b)->rd) {
886     return 0;
887   } else if (((const RdIdxPair *)a)->rd > ((const RdIdxPair *)b)->rd) {
888     return 1;
889   } else {
890     return -1;
891   }
892 }
893 
inter_modes_info_sort(const InterModesInfo * inter_modes_info,RdIdxPair * rd_idx_pair_arr)894 static void inter_modes_info_sort(const InterModesInfo *inter_modes_info,
895                                   RdIdxPair *rd_idx_pair_arr) {
896   if (inter_modes_info->num == 0) {
897     return;
898   }
899   for (int i = 0; i < inter_modes_info->num; ++i) {
900     rd_idx_pair_arr[i].idx = i;
901     rd_idx_pair_arr[i].rd = inter_modes_info->est_rd_arr[i];
902   }
903   qsort(rd_idx_pair_arr, inter_modes_info->num, sizeof(rd_idx_pair_arr[0]),
904         compare_rd_idx_pair);
905 }
906 
write_uniform_cost(int n,int v)907 static INLINE int write_uniform_cost(int n, int v) {
908   const int l = get_unsigned_bits(n);
909   const int m = (1 << l) - n;
910   if (l == 0) return 0;
911   if (v < m)
912     return av1_cost_literal(l - 1);
913   else
914     return av1_cost_literal(l);
915 }
916 
917 // Similar to store_cfl_required(), but for use during the RDO process,
918 // where we haven't yet determined whether this block uses CfL.
store_cfl_required_rdo(const AV1_COMMON * cm,const MACROBLOCK * x)919 static INLINE CFL_ALLOWED_TYPE store_cfl_required_rdo(const AV1_COMMON *cm,
920                                                       const MACROBLOCK *x) {
921   const MACROBLOCKD *xd = &x->e_mbd;
922 
923   if (cm->seq_params.monochrome || x->skip_chroma_rd) return CFL_DISALLOWED;
924 
925   if (!xd->cfl.is_chroma_reference) {
926     // For non-chroma-reference blocks, we should always store the luma pixels,
927     // in case the corresponding chroma-reference block uses CfL.
928     // Note that this can only happen for block sizes which are <8 on
929     // their shortest side, as otherwise they would be chroma reference
930     // blocks.
931     return CFL_ALLOWED;
932   }
933 
934   // For chroma reference blocks, we should store data in the encoder iff we're
935   // allowed to try out CfL.
936   return is_cfl_allowed(xd);
937 }
938 
// Decision boundaries for the fast transform-type pruning heuristics.
// A score above MID + MARGIN prunes one transform of a 1-D pair, and a score
// below MID - MARGIN prunes the other; scores inside the margin prune neither.
// The CORR_* pair thresholds the residual correlations in dct_vs_idtx(); the
// EDST_* pair thresholds the SVM projections in adst_vs_flipadst().
#define FAST_EXT_TX_CORR_MID 0.0
#define FAST_EXT_TX_EDST_MID 0.1
#define FAST_EXT_TX_CORR_MARGIN 0.5
#define FAST_EXT_TX_EDST_MARGIN 0.3

// Forward declaration so that code in this section can call inter_block_yrd()
// before its definition (presumably further down this file).
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode);
948 
pixel_dist_visible_only(const AV1_COMP * const cpi,const MACROBLOCK * x,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,const BLOCK_SIZE tx_bsize,int txb_rows,int txb_cols,int visible_rows,int visible_cols)949 static unsigned pixel_dist_visible_only(
950     const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
951     const int src_stride, const uint8_t *dst, const int dst_stride,
952     const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
953     int visible_cols) {
954   unsigned sse;
955 
956   if (txb_rows == visible_rows && txb_cols == visible_cols) {
957     cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
958     return sse;
959   }
960   const MACROBLOCKD *xd = &x->e_mbd;
961 
962   if (is_cur_buf_hbd(xd)) {
963     uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
964                                              visible_cols, visible_rows);
965     return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
966   }
967   sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
968                          visible_rows);
969   return sse;
970 }
971 
972 #if CONFIG_DIST_8X8
cdef_dist_8x8_16bit(uint16_t * dst,int dstride,uint16_t * src,int sstride,int coeff_shift)973 static uint64_t cdef_dist_8x8_16bit(uint16_t *dst, int dstride, uint16_t *src,
974                                     int sstride, int coeff_shift) {
975   uint64_t svar = 0;
976   uint64_t dvar = 0;
977   uint64_t sum_s = 0;
978   uint64_t sum_d = 0;
979   uint64_t sum_s2 = 0;
980   uint64_t sum_d2 = 0;
981   uint64_t sum_sd = 0;
982   uint64_t dist = 0;
983 
984   int i, j;
985   for (i = 0; i < 8; i++) {
986     for (j = 0; j < 8; j++) {
987       sum_s += src[i * sstride + j];
988       sum_d += dst[i * dstride + j];
989       sum_s2 += src[i * sstride + j] * src[i * sstride + j];
990       sum_d2 += dst[i * dstride + j] * dst[i * dstride + j];
991       sum_sd += src[i * sstride + j] * dst[i * dstride + j];
992     }
993   }
994   /* Compute the variance -- the calculation cannot go negative. */
995   svar = sum_s2 - ((sum_s * sum_s + 32) >> 6);
996   dvar = sum_d2 - ((sum_d * sum_d + 32) >> 6);
997 
998   // Tuning of jm's original dering distortion metric used in CDEF tool,
999   // suggested by jm
1000   const uint64_t a = 4;
1001   const uint64_t b = 2;
1002   const uint64_t c1 = (400 * a << 2 * coeff_shift);
1003   const uint64_t c2 = (b * 20000 * a * a << 4 * coeff_shift);
1004 
1005   dist = (uint64_t)floor(.5 + (sum_d2 + sum_s2 - 2 * sum_sd) * .5 *
1006                                   (svar + dvar + c1) /
1007                                   (sqrt(svar * (double)dvar + c2)));
1008 
1009   // Calibrate dist to have similar rate for the same QP with MSE only
1010   // distortion (as in master branch)
1011   dist = (uint64_t)((float)dist * 0.75);
1012 
1013   return dist;
1014 }
1015 
od_compute_var_4x4(uint16_t * x,int stride)1016 static int od_compute_var_4x4(uint16_t *x, int stride) {
1017   int sum;
1018   int s2;
1019   int i;
1020   sum = 0;
1021   s2 = 0;
1022   for (i = 0; i < 4; i++) {
1023     int j;
1024     for (j = 0; j < 4; j++) {
1025       int t;
1026 
1027       t = x[i * stride + j];
1028       sum += t;
1029       s2 += t * t;
1030     }
1031   }
1032 
1033   return (s2 - (sum * sum >> 4)) >> 4;
1034 }
1035 
/* OD_DIST_LP_MID controls the frequency weighting filter used for computing
   the distortion. For a value X, the filter is [1 X 1]/(X + 2) and
   is applied both horizontally and vertically. For X=5, the filter is
   a good approximation for the OD_QM8_Q4_HVS quantization matrix. */
#define OD_DIST_LP_MID (5)
/* Sum of the filter taps (1 + X + 1). The integer filter is applied without
   division, so squared filtered values must be divided by this to the fourth
   power (two dimensions, squared) to get unit DC response. */
#define OD_DIST_LP_NORM (OD_DIST_LP_MID + 2)
1042 
od_compute_dist_8x8(int use_activity_masking,uint16_t * x,uint16_t * y,od_coeff * e_lp,int stride)1043 static double od_compute_dist_8x8(int use_activity_masking, uint16_t *x,
1044                                   uint16_t *y, od_coeff *e_lp, int stride) {
1045   double sum;
1046   int min_var;
1047   double mean_var;
1048   double var_stat;
1049   double activity;
1050   double calibration;
1051   int i;
1052   int j;
1053   double vardist;
1054 
1055   vardist = 0;
1056 
1057 #if 1
1058   min_var = INT_MAX;
1059   mean_var = 0;
1060   for (i = 0; i < 3; i++) {
1061     for (j = 0; j < 3; j++) {
1062       int varx;
1063       int vary;
1064       varx = od_compute_var_4x4(x + 2 * i * stride + 2 * j, stride);
1065       vary = od_compute_var_4x4(y + 2 * i * stride + 2 * j, stride);
1066       min_var = OD_MINI(min_var, varx);
1067       mean_var += 1. / (1 + varx);
1068       /* The cast to (double) is to avoid an overflow before the sqrt.*/
1069       vardist += varx - 2 * sqrt(varx * (double)vary) + vary;
1070     }
1071   }
1072   /* We use a different variance statistic depending on whether activity
1073      masking is used, since the harmonic mean appeared slightly worse with
1074      masking off. The calibration constant just ensures that we preserve the
1075      rate compared to activity=1. */
1076   if (use_activity_masking) {
1077     calibration = 1.95;
1078     var_stat = 9. / mean_var;
1079   } else {
1080     calibration = 1.62;
1081     var_stat = min_var;
1082   }
1083   /* 1.62 is a calibration constant, 0.25 is a noise floor and 1/6 is the
1084      activity masking constant. */
1085   activity = calibration * pow(.25 + var_stat, -1. / 6);
1086 #else
1087   activity = 1;
1088 #endif  // 1
1089   sum = 0;
1090   for (i = 0; i < 8; i++) {
1091     for (j = 0; j < 8; j++)
1092       sum += e_lp[i * stride + j] * (double)e_lp[i * stride + j];
1093   }
1094   /* Normalize the filter to unit DC response. */
1095   sum *= 1. / (OD_DIST_LP_NORM * OD_DIST_LP_NORM * OD_DIST_LP_NORM *
1096                OD_DIST_LP_NORM);
1097   return activity * activity * (sum + vardist);
1098 }
1099 
// Note : Inputs x and y are in a pixel domain
/* Shared back end of od_compute_dist() and od_compute_dist_diff().
   |tmp| holds the horizontally filtered error, bsize_w x bsize_h with row
   stride bsize_w. This function:
   - applies the vertical [1 MID 1] pass into |e_lp|;
   - sums od_compute_dist_8x8() over every 8x8 sub-block;
   - rescales the sum by a qindex-dependent factor.
   |x| and |y| are the two pixel blocks, also with row stride bsize_w.
   Both dimensions are assumed to be multiples of 8 (>= 8). */
static double od_compute_dist_common(int activity_masking, uint16_t *x,
                                     uint16_t *y, int bsize_w, int bsize_h,
                                     int qindex, od_coeff *tmp,
                                     od_coeff *e_lp) {
  int i, j;
  double sum = 0;
  const int mid = OD_DIST_LP_MID;

  /* Vertical pass, top and bottom rows: the missing neighbour is replaced by
     the inner one, hence the doubled tap. */
  for (j = 0; j < bsize_w; j++) {
    e_lp[j] = mid * tmp[j] + 2 * tmp[bsize_w + j];
    e_lp[(bsize_h - 1) * bsize_w + j] = mid * tmp[(bsize_h - 1) * bsize_w + j] +
                                        2 * tmp[(bsize_h - 2) * bsize_w + j];
  }
  /* Vertical pass, interior rows: [1 MID 1]. */
  for (i = 1; i < bsize_h - 1; i++) {
    for (j = 0; j < bsize_w; j++) {
      e_lp[i * bsize_w + j] = mid * tmp[i * bsize_w + j] +
                              tmp[(i - 1) * bsize_w + j] +
                              tmp[(i + 1) * bsize_w + j];
    }
  }
  for (i = 0; i < bsize_h; i += 8) {
    for (j = 0; j < bsize_w; j += 8) {
      sum += od_compute_dist_8x8(activity_masking, &x[i * bsize_w + j],
                                 &y[i * bsize_w + j], &e_lp[i * bsize_w + j],
                                 bsize_w);
    }
  }
  /* Scale according to linear regression against SSE, for 8x8 blocks. */
  /* With masking: linear from 2.2 (qindex 99) to 1.7 (qindex 210), plus a
     quadratic term 2.5 * ((qindex - 99) / 99)^2 below qindex 99.
     Without masking: piecewise linear through (16, 2.0), (43, 1.5),
     (128, 1.4) and (209, 0.9), extrapolated outside that range.
     Expression order is kept as-is; it affects the floating-point result. */
  if (activity_masking) {
    sum *= 2.2 + (1.7 - 2.2) * (qindex - 99) / (210 - 99) +
           (qindex < 99 ? 2.5 * (qindex - 99) / 99 * (qindex - 99) / 99 : 0);
  } else {
    sum *= qindex >= 128
               ? 1.4 + (0.9 - 1.4) * (qindex - 128) / (209 - 128)
               : qindex <= 43 ? 1.5 + (2.0 - 1.5) * (qindex - 43) / (16 - 43)
                              : 1.5 + (1.4 - 1.5) * (qindex - 43) / (128 - 43);
  }

  return sum;
}
1141 
od_compute_dist(uint16_t * x,uint16_t * y,int bsize_w,int bsize_h,int qindex)1142 static double od_compute_dist(uint16_t *x, uint16_t *y, int bsize_w,
1143                               int bsize_h, int qindex) {
1144   assert(bsize_w >= 8 && bsize_h >= 8);
1145 
1146   int activity_masking = 0;
1147 
1148   int i, j;
1149   DECLARE_ALIGNED(16, od_coeff, e[MAX_SB_SQUARE]);
1150   DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]);
1151   DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]);
1152   for (i = 0; i < bsize_h; i++) {
1153     for (j = 0; j < bsize_w; j++) {
1154       e[i * bsize_w + j] = x[i * bsize_w + j] - y[i * bsize_w + j];
1155     }
1156   }
1157   int mid = OD_DIST_LP_MID;
1158   for (i = 0; i < bsize_h; i++) {
1159     tmp[i * bsize_w] = mid * e[i * bsize_w] + 2 * e[i * bsize_w + 1];
1160     tmp[i * bsize_w + bsize_w - 1] =
1161         mid * e[i * bsize_w + bsize_w - 1] + 2 * e[i * bsize_w + bsize_w - 2];
1162     for (j = 1; j < bsize_w - 1; j++) {
1163       tmp[i * bsize_w + j] = mid * e[i * bsize_w + j] + e[i * bsize_w + j - 1] +
1164                              e[i * bsize_w + j + 1];
1165     }
1166   }
1167   return od_compute_dist_common(activity_masking, x, y, bsize_w, bsize_h,
1168                                 qindex, tmp, e_lp);
1169 }
1170 
od_compute_dist_diff(uint16_t * x,int16_t * e,int bsize_w,int bsize_h,int qindex)1171 static double od_compute_dist_diff(uint16_t *x, int16_t *e, int bsize_w,
1172                                    int bsize_h, int qindex) {
1173   assert(bsize_w >= 8 && bsize_h >= 8);
1174 
1175   int activity_masking = 0;
1176 
1177   DECLARE_ALIGNED(16, uint16_t, y[MAX_SB_SQUARE]);
1178   DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]);
1179   DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]);
1180   int i, j;
1181   for (i = 0; i < bsize_h; i++) {
1182     for (j = 0; j < bsize_w; j++) {
1183       y[i * bsize_w + j] = x[i * bsize_w + j] - e[i * bsize_w + j];
1184     }
1185   }
1186   int mid = OD_DIST_LP_MID;
1187   for (i = 0; i < bsize_h; i++) {
1188     tmp[i * bsize_w] = mid * e[i * bsize_w] + 2 * e[i * bsize_w + 1];
1189     tmp[i * bsize_w + bsize_w - 1] =
1190         mid * e[i * bsize_w + bsize_w - 1] + 2 * e[i * bsize_w + bsize_w - 2];
1191     for (j = 1; j < bsize_w - 1; j++) {
1192       tmp[i * bsize_w + j] = mid * e[i * bsize_w + j] + e[i * bsize_w + j - 1] +
1193                              e[i * bsize_w + j + 1];
1194     }
1195   }
1196   return od_compute_dist_common(activity_masking, x, y, bsize_w, bsize_h,
1197                                 qindex, tmp, e_lp);
1198 }
1199 
av1_dist_8x8(const AV1_COMP * const cpi,const MACROBLOCK * x,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,const BLOCK_SIZE tx_bsize,int bsw,int bsh,int visible_w,int visible_h,int qindex)1200 int64_t av1_dist_8x8(const AV1_COMP *const cpi, const MACROBLOCK *x,
1201                      const uint8_t *src, int src_stride, const uint8_t *dst,
1202                      int dst_stride, const BLOCK_SIZE tx_bsize, int bsw,
1203                      int bsh, int visible_w, int visible_h, int qindex) {
1204   int64_t d = 0;
1205   int i, j;
1206   const MACROBLOCKD *xd = &x->e_mbd;
1207 
1208   DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]);
1209   DECLARE_ALIGNED(16, uint16_t, rec[MAX_SB_SQUARE]);
1210 
1211   assert(bsw >= 8);
1212   assert(bsh >= 8);
1213   assert((bsw & 0x07) == 0);
1214   assert((bsh & 0x07) == 0);
1215 
1216   if (x->tune_metric == AOM_TUNE_CDEF_DIST ||
1217       x->tune_metric == AOM_TUNE_DAALA_DIST) {
1218     if (is_cur_buf_hbd(xd)) {
1219       for (j = 0; j < bsh; j++)
1220         for (i = 0; i < bsw; i++)
1221           orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
1222 
1223       if ((bsw == visible_w) && (bsh == visible_h)) {
1224         for (j = 0; j < bsh; j++)
1225           for (i = 0; i < bsw; i++)
1226             rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];
1227       } else {
1228         for (j = 0; j < visible_h; j++)
1229           for (i = 0; i < visible_w; i++)
1230             rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];
1231 
1232         if (visible_w < bsw) {
1233           for (j = 0; j < bsh; j++)
1234             for (i = visible_w; i < bsw; i++)
1235               rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
1236         }
1237 
1238         if (visible_h < bsh) {
1239           for (j = visible_h; j < bsh; j++)
1240             for (i = 0; i < bsw; i++)
1241               rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
1242         }
1243       }
1244     } else {
1245       for (j = 0; j < bsh; j++)
1246         for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
1247 
1248       if ((bsw == visible_w) && (bsh == visible_h)) {
1249         for (j = 0; j < bsh; j++)
1250           for (i = 0; i < bsw; i++) rec[j * bsw + i] = dst[j * dst_stride + i];
1251       } else {
1252         for (j = 0; j < visible_h; j++)
1253           for (i = 0; i < visible_w; i++)
1254             rec[j * bsw + i] = dst[j * dst_stride + i];
1255 
1256         if (visible_w < bsw) {
1257           for (j = 0; j < bsh; j++)
1258             for (i = visible_w; i < bsw; i++)
1259               rec[j * bsw + i] = src[j * src_stride + i];
1260         }
1261 
1262         if (visible_h < bsh) {
1263           for (j = visible_h; j < bsh; j++)
1264             for (i = 0; i < bsw; i++)
1265               rec[j * bsw + i] = src[j * src_stride + i];
1266         }
1267       }
1268     }
1269   }
1270 
1271   if (x->tune_metric == AOM_TUNE_DAALA_DIST) {
1272     d = (int64_t)od_compute_dist(orig, rec, bsw, bsh, qindex);
1273   } else if (x->tune_metric == AOM_TUNE_CDEF_DIST) {
1274     int coeff_shift = AOMMAX(xd->bd - 8, 0);
1275 
1276     for (i = 0; i < bsh; i += 8) {
1277       for (j = 0; j < bsw; j += 8) {
1278         d += cdef_dist_8x8_16bit(&rec[i * bsw + j], bsw, &orig[i * bsw + j],
1279                                  bsw, coeff_shift);
1280       }
1281     }
1282     if (is_cur_buf_hbd(xd)) d = ((uint64_t)d) >> 2 * coeff_shift;
1283   } else {
1284     // Otherwise, MSE by default
1285     d = pixel_dist_visible_only(cpi, x, src, src_stride, dst, dst_stride,
1286                                 tx_bsize, bsh, bsw, visible_h, visible_w);
1287   }
1288 
1289   return d;
1290 }
1291 
dist_8x8_diff(const MACROBLOCK * x,const uint8_t * src,int src_stride,const int16_t * diff,int diff_stride,int bsw,int bsh,int visible_w,int visible_h,int qindex)1292 static int64_t dist_8x8_diff(const MACROBLOCK *x, const uint8_t *src,
1293                              int src_stride, const int16_t *diff,
1294                              int diff_stride, int bsw, int bsh, int visible_w,
1295                              int visible_h, int qindex) {
1296   int64_t d = 0;
1297   int i, j;
1298   const MACROBLOCKD *xd = &x->e_mbd;
1299 
1300   DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]);
1301   DECLARE_ALIGNED(16, int16_t, diff16[MAX_SB_SQUARE]);
1302 
1303   assert(bsw >= 8);
1304   assert(bsh >= 8);
1305   assert((bsw & 0x07) == 0);
1306   assert((bsh & 0x07) == 0);
1307 
1308   if (x->tune_metric == AOM_TUNE_CDEF_DIST ||
1309       x->tune_metric == AOM_TUNE_DAALA_DIST) {
1310     if (is_cur_buf_hbd(xd)) {
1311       for (j = 0; j < bsh; j++)
1312         for (i = 0; i < bsw; i++)
1313           orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
1314     } else {
1315       for (j = 0; j < bsh; j++)
1316         for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
1317     }
1318 
1319     if ((bsw == visible_w) && (bsh == visible_h)) {
1320       for (j = 0; j < bsh; j++)
1321         for (i = 0; i < bsw; i++)
1322           diff16[j * bsw + i] = diff[j * diff_stride + i];
1323     } else {
1324       for (j = 0; j < visible_h; j++)
1325         for (i = 0; i < visible_w; i++)
1326           diff16[j * bsw + i] = diff[j * diff_stride + i];
1327 
1328       if (visible_w < bsw) {
1329         for (j = 0; j < bsh; j++)
1330           for (i = visible_w; i < bsw; i++) diff16[j * bsw + i] = 0;
1331       }
1332 
1333       if (visible_h < bsh) {
1334         for (j = visible_h; j < bsh; j++)
1335           for (i = 0; i < bsw; i++) diff16[j * bsw + i] = 0;
1336       }
1337     }
1338   }
1339 
1340   if (x->tune_metric == AOM_TUNE_DAALA_DIST) {
1341     d = (int64_t)od_compute_dist_diff(orig, diff16, bsw, bsh, qindex);
1342   } else if (x->tune_metric == AOM_TUNE_CDEF_DIST) {
1343     int coeff_shift = AOMMAX(xd->bd - 8, 0);
1344     DECLARE_ALIGNED(16, uint16_t, dst16[MAX_SB_SQUARE]);
1345 
1346     for (i = 0; i < bsh; i++) {
1347       for (j = 0; j < bsw; j++) {
1348         dst16[i * bsw + j] = orig[i * bsw + j] - diff16[i * bsw + j];
1349       }
1350     }
1351 
1352     for (i = 0; i < bsh; i += 8) {
1353       for (j = 0; j < bsw; j += 8) {
1354         d += cdef_dist_8x8_16bit(&dst16[i * bsw + j], bsw, &orig[i * bsw + j],
1355                                  bsw, coeff_shift);
1356       }
1357     }
1358     // Don't scale 'd' for HBD since it will be done by caller side for diff
1359     // input
1360   } else {
1361     // Otherwise, MSE by default
1362     d = aom_sum_squares_2d_i16(diff, diff_stride, visible_w, visible_h);
1363   }
1364 
1365   return d;
1366 }
1367 #endif  // CONFIG_DIST_8X8
1368 
get_energy_distribution_fine(const AV1_COMP * cpi,BLOCK_SIZE bsize,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int need_4th,double * hordist,double * verdist)1369 static void get_energy_distribution_fine(const AV1_COMP *cpi, BLOCK_SIZE bsize,
1370                                          const uint8_t *src, int src_stride,
1371                                          const uint8_t *dst, int dst_stride,
1372                                          int need_4th, double *hordist,
1373                                          double *verdist) {
1374   const int bw = block_size_wide[bsize];
1375   const int bh = block_size_high[bsize];
1376   unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1377 
1378   if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
1379     // Special cases: calculate 'esq' values manually, as we don't have 'vf'
1380     // functions for the 16 (very small) sub-blocks of this block.
1381     const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
1382     const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
1383     assert(bw <= 32);
1384     assert(bh <= 32);
1385     assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
1386     if (cpi->common.seq_params.use_highbitdepth) {
1387       const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
1388       const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
1389       for (int i = 0; i < bh; ++i)
1390         for (int j = 0; j < bw; ++j) {
1391           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
1392           esq[index] +=
1393               (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
1394               (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
1395         }
1396     } else {
1397       for (int i = 0; i < bh; ++i)
1398         for (int j = 0; j < bw; ++j) {
1399           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
1400           esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
1401                         (src[j + i * src_stride] - dst[j + i * dst_stride]);
1402         }
1403     }
1404   } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
1405     const int f_index =
1406         (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
1407     assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
1408     const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
1409     assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
1410     assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
1411     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
1412     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
1413                             &esq[1]);
1414     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
1415                             &esq[2]);
1416     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
1417                             dst_stride, &esq[3]);
1418     src += bh / 4 * src_stride;
1419     dst += bh / 4 * dst_stride;
1420 
1421     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
1422     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
1423                             &esq[5]);
1424     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
1425                             &esq[6]);
1426     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
1427                             dst_stride, &esq[7]);
1428     src += bh / 4 * src_stride;
1429     dst += bh / 4 * dst_stride;
1430 
1431     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
1432     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
1433                             &esq[9]);
1434     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
1435                             &esq[10]);
1436     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
1437                             dst_stride, &esq[11]);
1438     src += bh / 4 * src_stride;
1439     dst += bh / 4 * dst_stride;
1440 
1441     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
1442     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
1443                             &esq[13]);
1444     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
1445                             &esq[14]);
1446     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
1447                             dst_stride, &esq[15]);
1448   }
1449 
1450   double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
1451                  esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
1452                  esq[12] + esq[13] + esq[14] + esq[15];
1453   if (total > 0) {
1454     const double e_recip = 1.0 / total;
1455     hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
1456     hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
1457     hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
1458     if (need_4th) {
1459       hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
1460     }
1461     verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
1462     verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
1463     verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
1464     if (need_4th) {
1465       verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
1466     }
1467   } else {
1468     hordist[0] = verdist[0] = 0.25;
1469     hordist[1] = verdist[1] = 0.25;
1470     hordist[2] = verdist[2] = 0.25;
1471     if (need_4th) {
1472       hordist[3] = verdist[3] = 0.25;
1473     }
1474   }
1475 }
1476 
adst_vs_flipadst(const AV1_COMP * cpi,BLOCK_SIZE bsize,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride)1477 static int adst_vs_flipadst(const AV1_COMP *cpi, BLOCK_SIZE bsize,
1478                             const uint8_t *src, int src_stride,
1479                             const uint8_t *dst, int dst_stride) {
1480   int prune_bitmask = 0;
1481   double svm_proj_h = 0, svm_proj_v = 0;
1482   double hdist[3] = { 0, 0, 0 }, vdist[3] = { 0, 0, 0 };
1483   get_energy_distribution_fine(cpi, bsize, src, src_stride, dst, dst_stride, 0,
1484                                hdist, vdist);
1485 
1486   svm_proj_v = vdist[0] * ADST_FLIP_SVM[0] + vdist[1] * ADST_FLIP_SVM[1] +
1487                vdist[2] * ADST_FLIP_SVM[2] + ADST_FLIP_SVM[3];
1488   svm_proj_h = hdist[0] * ADST_FLIP_SVM[4] + hdist[1] * ADST_FLIP_SVM[5] +
1489                hdist[2] * ADST_FLIP_SVM[6] + ADST_FLIP_SVM[7];
1490   if (svm_proj_v > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)
1491     prune_bitmask |= 1 << FLIPADST_1D;
1492   else if (svm_proj_v < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
1493     prune_bitmask |= 1 << ADST_1D;
1494 
1495   if (svm_proj_h > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)
1496     prune_bitmask |= 1 << (FLIPADST_1D + 8);
1497   else if (svm_proj_h < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
1498     prune_bitmask |= 1 << (ADST_1D + 8);
1499 
1500   return prune_bitmask;
1501 }
1502 
dct_vs_idtx(const int16_t * diff,int stride,int w,int h)1503 static int dct_vs_idtx(const int16_t *diff, int stride, int w, int h) {
1504   float hcorr, vcorr;
1505   int prune_bitmask = 0;
1506   av1_get_horver_correlation_full(diff, stride, w, h, &hcorr, &vcorr);
1507 
1508   if (vcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)
1509     prune_bitmask |= 1 << IDTX_1D;
1510   else if (vcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)
1511     prune_bitmask |= 1 << DCT_1D;
1512 
1513   if (hcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)
1514     prune_bitmask |= 1 << (IDTX_1D + 8);
1515   else if (hcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)
1516     prune_bitmask |= 1 << (DCT_1D + 8);
1517   return prune_bitmask;
1518 }
1519 
1520 // Performance drop: 0.5%, Speed improvement: 24%
prune_two_for_sby(const AV1_COMP * cpi,BLOCK_SIZE bsize,MACROBLOCK * x,const MACROBLOCKD * xd,int adst_flipadst,int dct_idtx)1521 static int prune_two_for_sby(const AV1_COMP *cpi, BLOCK_SIZE bsize,
1522                              MACROBLOCK *x, const MACROBLOCKD *xd,
1523                              int adst_flipadst, int dct_idtx) {
1524   int prune = 0;
1525 
1526   if (adst_flipadst) {
1527     const struct macroblock_plane *const p = &x->plane[0];
1528     const struct macroblockd_plane *const pd = &xd->plane[0];
1529     prune |= adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride,
1530                               pd->dst.buf, pd->dst.stride);
1531   }
1532   if (dct_idtx) {
1533     av1_subtract_plane(x, bsize, 0);
1534     const struct macroblock_plane *const p = &x->plane[0];
1535     const int bw = block_size_wide[bsize];
1536     const int bh = block_size_high[bsize];
1537     prune |= dct_vs_idtx(p->src_diff, bw, bw, bh);
1538   }
1539 
1540   return prune;
1541 }
1542 
1543 // Performance drop: 0.3%, Speed improvement: 5%
prune_one_for_sby(const AV1_COMP * cpi,BLOCK_SIZE bsize,const MACROBLOCK * x,const MACROBLOCKD * xd)1544 static int prune_one_for_sby(const AV1_COMP *cpi, BLOCK_SIZE bsize,
1545                              const MACROBLOCK *x, const MACROBLOCKD *xd) {
1546   const struct macroblock_plane *const p = &x->plane[0];
1547   const struct macroblockd_plane *const pd = &xd->plane[0];
1548   return adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride, pd->dst.buf,
1549                           pd->dst.stride);
1550 }
1551 
// 1D Transforms used in inter set, this needs to be changed if
// ext_tx_used_inter is changed
//
// Rows are indexed by the inter TX set index (ext_tx_set_index[1][...], see
// prune_tx()); columns are indexed by the 1D transform type (DCT_1D, ADST_1D,
// FLIPADST_1D, IDTX_1D -- presumably in that enum order; verify against the
// TX_TYPES_1D definition). A 1 means that 1D transform is available in the
// set, which prune_tx() uses to decide which pruning tests are meaningful.
static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
  { 1, 0, 0, 0 },
  { 1, 1, 1, 1 },
  { 1, 1, 1, 1 },
  { 1, 0, 0, 1 },
};
1560 
get_energy_distribution_finer(const int16_t * diff,int stride,int bw,int bh,float * hordist,float * verdist)1561 static void get_energy_distribution_finer(const int16_t *diff, int stride,
1562                                           int bw, int bh, float *hordist,
1563                                           float *verdist) {
1564   // First compute downscaled block energy values (esq); downscale factors
1565   // are defined by w_shift and h_shift.
1566   unsigned int esq[256];
1567   const int w_shift = bw <= 8 ? 0 : 1;
1568   const int h_shift = bh <= 8 ? 0 : 1;
1569   const int esq_w = bw >> w_shift;
1570   const int esq_h = bh >> h_shift;
1571   const int esq_sz = esq_w * esq_h;
1572   int i, j;
1573   memset(esq, 0, esq_sz * sizeof(esq[0]));
1574   if (w_shift) {
1575     for (i = 0; i < bh; i++) {
1576       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1577       const int16_t *cur_diff_row = diff + i * stride;
1578       for (j = 0; j < bw; j += 2) {
1579         cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
1580                                 cur_diff_row[j + 1] * cur_diff_row[j + 1]);
1581       }
1582     }
1583   } else {
1584     for (i = 0; i < bh; i++) {
1585       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1586       const int16_t *cur_diff_row = diff + i * stride;
1587       for (j = 0; j < bw; j++) {
1588         cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
1589       }
1590     }
1591   }
1592 
1593   uint64_t total = 0;
1594   for (i = 0; i < esq_sz; i++) total += esq[i];
1595 
1596   // Output hordist and verdist arrays are normalized 1D projections of esq
1597   if (total == 0) {
1598     float hor_val = 1.0f / esq_w;
1599     for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
1600     float ver_val = 1.0f / esq_h;
1601     for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
1602     return;
1603   }
1604 
1605   const float e_recip = 1.0f / (float)total;
1606   memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
1607   memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
1608   const unsigned int *cur_esq_row;
1609   for (i = 0; i < esq_h - 1; i++) {
1610     cur_esq_row = esq + i * esq_w;
1611     for (j = 0; j < esq_w - 1; j++) {
1612       hordist[j] += (float)cur_esq_row[j];
1613       verdist[i] += (float)cur_esq_row[j];
1614     }
1615     verdist[i] += (float)cur_esq_row[j];
1616   }
1617   cur_esq_row = esq + i * esq_w;
1618   for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
1619 
1620   for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
1621   for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
1622 }
1623 
1624 // Similar to get_horver_correlation, but also takes into account first
1625 // row/column, when computing horizontal/vertical correlation.
av1_get_horver_correlation_full_c(const int16_t * diff,int stride,int width,int height,float * hcorr,float * vcorr)1626 void av1_get_horver_correlation_full_c(const int16_t *diff, int stride,
1627                                        int width, int height, float *hcorr,
1628                                        float *vcorr) {
1629   // The following notation is used:
1630   // x - current pixel
1631   // y - left neighbor pixel
1632   // z - top neighbor pixel
1633   int64_t x_sum = 0, x2_sum = 0, xy_sum = 0, xz_sum = 0;
1634   int64_t x_firstrow = 0, x_finalrow = 0, x_firstcol = 0, x_finalcol = 0;
1635   int64_t x2_firstrow = 0, x2_finalrow = 0, x2_firstcol = 0, x2_finalcol = 0;
1636 
1637   // First, process horizontal correlation on just the first row
1638   x_sum += diff[0];
1639   x2_sum += diff[0] * diff[0];
1640   x_firstrow += diff[0];
1641   x2_firstrow += diff[0] * diff[0];
1642   for (int j = 1; j < width; ++j) {
1643     const int16_t x = diff[j];
1644     const int16_t y = diff[j - 1];
1645     x_sum += x;
1646     x_firstrow += x;
1647     x2_sum += x * x;
1648     x2_firstrow += x * x;
1649     xy_sum += x * y;
1650   }
1651 
1652   // Process vertical correlation in the first column
1653   x_firstcol += diff[0];
1654   x2_firstcol += diff[0] * diff[0];
1655   for (int i = 1; i < height; ++i) {
1656     const int16_t x = diff[i * stride];
1657     const int16_t z = diff[(i - 1) * stride];
1658     x_sum += x;
1659     x_firstcol += x;
1660     x2_sum += x * x;
1661     x2_firstcol += x * x;
1662     xz_sum += x * z;
1663   }
1664 
1665   // Now process horiz and vert correlation through the rest unit
1666   for (int i = 1; i < height; ++i) {
1667     for (int j = 1; j < width; ++j) {
1668       const int16_t x = diff[i * stride + j];
1669       const int16_t y = diff[i * stride + j - 1];
1670       const int16_t z = diff[(i - 1) * stride + j];
1671       x_sum += x;
1672       x2_sum += x * x;
1673       xy_sum += x * y;
1674       xz_sum += x * z;
1675     }
1676   }
1677 
1678   for (int j = 0; j < width; ++j) {
1679     x_finalrow += diff[(height - 1) * stride + j];
1680     x2_finalrow +=
1681         diff[(height - 1) * stride + j] * diff[(height - 1) * stride + j];
1682   }
1683   for (int i = 0; i < height; ++i) {
1684     x_finalcol += diff[i * stride + width - 1];
1685     x2_finalcol += diff[i * stride + width - 1] * diff[i * stride + width - 1];
1686   }
1687 
1688   int64_t xhor_sum = x_sum - x_finalcol;
1689   int64_t xver_sum = x_sum - x_finalrow;
1690   int64_t y_sum = x_sum - x_firstcol;
1691   int64_t z_sum = x_sum - x_firstrow;
1692   int64_t x2hor_sum = x2_sum - x2_finalcol;
1693   int64_t x2ver_sum = x2_sum - x2_finalrow;
1694   int64_t y2_sum = x2_sum - x2_firstcol;
1695   int64_t z2_sum = x2_sum - x2_firstrow;
1696 
1697   const float num_hor = (float)(height * (width - 1));
1698   const float num_ver = (float)((height - 1) * width);
1699 
1700   const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
1701   const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;
1702 
1703   const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
1704   const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;
1705 
1706   const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
1707   const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;
1708 
1709   if (xhor_var_n > 0 && y_var_n > 0) {
1710     *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
1711     *hcorr = *hcorr < 0 ? 0 : *hcorr;
1712   } else {
1713     *hcorr = 1.0;
1714   }
1715   if (xver_var_n > 0 && z_var_n > 0) {
1716     *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
1717     *vcorr = *vcorr < 0 ? 0 : *vcorr;
1718   } else {
1719     *vcorr = 1.0;
1720   }
1721 }
1722 
// Transforms raw scores into a probability distribution across 16 TX types.
//
// Each score is shifted, clamped to [0, 100] and raised to the 8th power.
// The results are then divided by their sum. Any value below 1e-4 of the sum
// is set to exactly 0 instead.
static void score_2D_transform_pow8(float *scores_2D, float shift) {
  float total = 0.0f;
  for (int k = 0; k < 16; ++k) {
    const float clamped = AOMMIN(AOMMAX(scores_2D[k] + shift, 0.0f), 100.0f);
    const float pow2 = clamped * clamped;
    const float pow4 = pow2 * pow2;
    scores_2D[k] = pow4 * pow4;
    total += scores_2D[k];
  }
  for (int k = 0; k < 16; ++k) {
    scores_2D[k] = (scores_2D[k] < total * 1e-4) ? 0.0f : scores_2D[k] / total;
  }
}
1741 
// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
//
// Indexed by TX_SIZE (see the per-row comments). prune_tx_2D() reads entry
// [pruning_aggressiveness - 1], where pruning_aggressiveness comes from its
// prune_aggr_table (values up to 10). A NULL row means no calibrated
// thresholds exist for that transform size. Note that TX_16X16 has only 10
// entries, so it supports aggressiveness levels 1..10.
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02820f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
1802 
prune_tx_2D(MACROBLOCK * x,BLOCK_SIZE bsize,TX_SIZE tx_size,int blk_row,int blk_col,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_mode)1803 static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1804                             int blk_row, int blk_col, TxSetType tx_set_type,
1805                             TX_TYPE_PRUNE_MODE prune_mode) {
1806   static const int tx_type_table_2D[16] = {
1807     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1808     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1809     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1810     H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1811   };
1812   if (tx_set_type != EXT_TX_SET_ALL16 &&
1813       tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1814     return 0;
1815   const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1816   const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1817   if (!nn_config_hor || !nn_config_ver) return 0;  // Model not established yet.
1818 
1819   aom_clear_system_state();
1820   float hfeatures[16], vfeatures[16];
1821   float hscores[4], vscores[4];
1822   float scores_2D[16];
1823   const int bw = tx_size_wide[tx_size];
1824   const int bh = tx_size_high[tx_size];
1825   const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1826   const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1827   assert(hfeatures_num <= 16);
1828   assert(vfeatures_num <= 16);
1829 
1830   const struct macroblock_plane *const p = &x->plane[0];
1831   const int diff_stride = block_size_wide[bsize];
1832   const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1833   get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1834                                 vfeatures);
1835   av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1836                                   &hfeatures[hfeatures_num - 1],
1837                                   &vfeatures[vfeatures_num - 1]);
1838   av1_nn_predict(hfeatures, nn_config_hor, hscores);
1839   av1_nn_predict(vfeatures, nn_config_ver, vscores);
1840   aom_clear_system_state();
1841 
1842   float score_2D_average = 0.0f;
1843   for (int i = 0; i < 4; i++) {
1844     float *cur_scores_2D = scores_2D + i * 4;
1845     cur_scores_2D[0] = vscores[i] * hscores[0];
1846     cur_scores_2D[1] = vscores[i] * hscores[1];
1847     cur_scores_2D[2] = vscores[i] * hscores[2];
1848     cur_scores_2D[3] = vscores[i] * hscores[3];
1849     score_2D_average += cur_scores_2D[0] + cur_scores_2D[1] + cur_scores_2D[2] +
1850                         cur_scores_2D[3];
1851   }
1852   score_2D_average /= 16;
1853 
1854   const int prune_aggr_table[2][2] = { { 6, 4 }, { 10, 7 } };
1855   int pruning_aggressiveness = 1;
1856   if (tx_set_type == EXT_TX_SET_ALL16) {
1857     score_2D_transform_pow8(scores_2D, (10 - score_2D_average));
1858     pruning_aggressiveness =
1859         prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][0];
1860   } else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) {
1861     score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
1862     pruning_aggressiveness =
1863         prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][1];
1864   }
1865 
1866   // Always keep the TX type with the highest score, prune all others with
1867   // score below score_thresh.
1868   int max_score_i = 0;
1869   float max_score = 0.0f;
1870   for (int i = 0; i < 16; i++) {
1871     if (scores_2D[i] > max_score &&
1872         av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
1873       max_score = scores_2D[i];
1874       max_score_i = i;
1875     }
1876   }
1877 
1878   const float score_thresh =
1879       prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness - 1];
1880 
1881   uint16_t prune_bitmask = 0;
1882   for (int i = 0; i < 16; i++) {
1883     if (scores_2D[i] < score_thresh && i != max_score_i)
1884       prune_bitmask |= (1 << tx_type_table_2D[i]);
1885   }
1886   return prune_bitmask;
1887 }
1888 
// ((prune >> vtx_tab[tx_type]) & 1)
//
// Indexed by a 4-bit set of pruned vertical 1D transform types (bit k set =
// 1D type k pruned). Each entry is a 16-bit mask over 2D TX types: bit t is
// set iff the vertical 1D component of TX type t is in that set. The entries
// are OR-compositions of the single-bit ones (e.g. 0x0425 | 0x108a == 0x14af),
// and entry 15 covers every TX type.
static const uint16_t prune_v_mask[] = {
  0x0000, 0x0425, 0x108a, 0x14af, 0x4150, 0x4575, 0x51da, 0x55ff,
  0xaa00, 0xae25, 0xba8a, 0xbeaf, 0xeb50, 0xef75, 0xfbda, 0xffff,
};

// ((prune >> (htx_tab[tx_type] + 8)) & 1)
//
// Horizontal counterpart of prune_v_mask: indexed by the 4-bit set of pruned
// horizontal 1D types (bits 8-11 of a 1D prune mask, shifted down). Bit t is
// set iff the horizontal 1D component of TX type t is in that set.
static const uint16_t prune_h_mask[] = {
  0x0000, 0x0813, 0x210c, 0x291f, 0x80e0, 0x88f3, 0xa1ec, 0xa9ff,
  0x5600, 0x5e13, 0x770c, 0x7f1f, 0xd6e0, 0xdef3, 0xf7ec, 0xffff,
};
1900 
gen_tx_search_prune_mask(int tx_search_prune)1901 static INLINE uint16_t gen_tx_search_prune_mask(int tx_search_prune) {
1902   uint8_t prune_v = tx_search_prune & 0x0F;
1903   uint8_t prune_h = (tx_search_prune >> 8) & 0x0F;
1904   return (prune_v_mask[prune_v] & prune_h_mask[prune_h]);
1905 }
1906 
prune_tx(const AV1_COMP * cpi,BLOCK_SIZE bsize,MACROBLOCK * x,const MACROBLOCKD * const xd,int tx_set_type)1907 static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
1908                      const MACROBLOCKD *const xd, int tx_set_type) {
1909   x->tx_search_prune[tx_set_type] = 0;
1910   x->tx_split_prune_flag = 0;
1911   const MB_MODE_INFO *mbmi = xd->mi[0];
1912   const int is_inter = is_inter_block(mbmi);
1913   if ((is_inter && cpi->oxcf.use_inter_dct_only) ||
1914       (!is_inter && cpi->oxcf.use_intra_dct_only)) {
1915     x->tx_search_prune[tx_set_type] = ~(1 << DCT_DCT);
1916     return;
1917   }
1918   if (!is_inter || cpi->sf.tx_type_search.prune_mode == NO_PRUNE ||
1919       x->use_default_inter_tx_type || xd->lossless[mbmi->segment_id] ||
1920       x->cb_partition_scan)
1921     return;
1922   int tx_set = ext_tx_set_index[1][tx_set_type];
1923   assert(tx_set >= 0);
1924   const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
1925   int prune = 0;
1926   switch (cpi->sf.tx_type_search.prune_mode) {
1927     case NO_PRUNE: return;
1928     case PRUNE_ONE:
1929       if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return;
1930       prune = prune_one_for_sby(cpi, bsize, x, xd);
1931       x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune);
1932       break;
1933     case PRUNE_TWO:
1934       if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
1935         if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return;
1936         prune = prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
1937       } else if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) {
1938         prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
1939       } else {
1940         prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
1941       }
1942       x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune);
1943       break;
1944     case PRUNE_2D_ACCURATE:
1945     case PRUNE_2D_FAST: break;
1946     default: assert(0);
1947   }
1948 }
1949 
model_rd_from_sse(const AV1_COMP * const cpi,const MACROBLOCK * const x,BLOCK_SIZE plane_bsize,int plane,int64_t sse,int num_samples,int * rate,int64_t * dist)1950 static void model_rd_from_sse(const AV1_COMP *const cpi,
1951                               const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
1952                               int plane, int64_t sse, int num_samples,
1953                               int *rate, int64_t *dist) {
1954   (void)num_samples;
1955   const MACROBLOCKD *const xd = &x->e_mbd;
1956   const struct macroblockd_plane *const pd = &xd->plane[plane];
1957   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
1958 
1959   // Fast approximate the modelling function.
1960   if (cpi->sf.simple_model_rd_from_var) {
1961     const int64_t square_error = sse;
1962     int quantizer = pd->dequant_Q3[1] >> dequant_shift;
1963     if (quantizer < 120)
1964       *rate = (int)AOMMIN(
1965           (square_error * (280 - quantizer)) >> (16 - AV1_PROB_COST_SHIFT),
1966           INT_MAX);
1967     else
1968       *rate = 0;
1969     assert(*rate >= 0);
1970     *dist = (square_error * quantizer) >> 8;
1971   } else {
1972     av1_model_rd_from_var_lapndz(sse, num_pels_log2_lookup[plane_bsize],
1973                                  pd->dequant_Q3[1] >> dequant_shift, rate,
1974                                  dist);
1975   }
1976   *dist <<= 4;
1977 }
1978 
get_sse(const AV1_COMP * cpi,const MACROBLOCK * x)1979 static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
1980   const AV1_COMMON *cm = &cpi->common;
1981   const int num_planes = av1_num_planes(cm);
1982   const MACROBLOCKD *xd = &x->e_mbd;
1983   const MB_MODE_INFO *mbmi = xd->mi[0];
1984   int64_t total_sse = 0;
1985   for (int plane = 0; plane < num_planes; ++plane) {
1986     const struct macroblock_plane *const p = &x->plane[plane];
1987     const struct macroblockd_plane *const pd = &xd->plane[plane];
1988     const BLOCK_SIZE bs = get_plane_block_size(mbmi->sb_type, pd->subsampling_x,
1989                                                pd->subsampling_y);
1990     unsigned int sse;
1991 
1992     if (x->skip_chroma_rd && plane) continue;
1993 
1994     cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
1995                        &sse);
1996     total_sse += sse;
1997   }
1998   total_sse <<= 4;
1999   return total_sse;
2000 }
2001 
model_rd_for_sb(const AV1_COMP * const cpi,BLOCK_SIZE bsize,MACROBLOCK * x,MACROBLOCKD * xd,int plane_from,int plane_to,int mi_row,int mi_col,int * out_rate_sum,int64_t * out_dist_sum,int * skip_txfm_sb,int64_t * skip_sse_sb,int * plane_rate,int64_t * plane_sse,int64_t * plane_dist)2002 static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
2003                             MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
2004                             int plane_to, int mi_row, int mi_col,
2005                             int *out_rate_sum, int64_t *out_dist_sum,
2006                             int *skip_txfm_sb, int64_t *skip_sse_sb,
2007                             int *plane_rate, int64_t *plane_sse,
2008                             int64_t *plane_dist) {
2009   // Note our transform coeffs are 8 times an orthogonal transform.
2010   // Hence quantizer step is also 8 times. To get effective quantizer
2011   // we need to divide by 8 before sending to modeling function.
2012   int plane;
2013   (void)mi_row;
2014   (void)mi_col;
2015   const int ref = xd->mi[0]->ref_frame[0];
2016 
2017   int64_t rate_sum = 0;
2018   int64_t dist_sum = 0;
2019   int64_t total_sse = 0;
2020 
2021   for (plane = plane_from; plane <= plane_to; ++plane) {
2022     struct macroblock_plane *const p = &x->plane[plane];
2023     struct macroblockd_plane *const pd = &xd->plane[plane];
2024     const BLOCK_SIZE plane_bsize =
2025         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2026     const int bw = block_size_wide[plane_bsize];
2027     const int bh = block_size_high[plane_bsize];
2028     int64_t sse;
2029     int rate;
2030     int64_t dist;
2031 
2032     if (x->skip_chroma_rd && plane) continue;
2033 
2034     if (is_cur_buf_hbd(xd)) {
2035       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2036                            pd->dst.stride, bw, bh);
2037     } else {
2038       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2039                     bh);
2040     }
2041     sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
2042 
2043     model_rd_from_sse(cpi, x, plane_bsize, plane, sse, bw * bh, &rate, &dist);
2044 
2045     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2046 
2047     total_sse += sse;
2048     rate_sum += rate;
2049     dist_sum += dist;
2050     if (plane_rate) plane_rate[plane] = rate;
2051     if (plane_sse) plane_sse[plane] = sse;
2052     if (plane_dist) plane_dist[plane] = dist;
2053     assert(rate_sum >= 0);
2054   }
2055 
2056   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2057   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2058   rate_sum = AOMMIN(rate_sum, INT_MAX);
2059   *out_rate_sum = (int)rate_sum;
2060   *out_dist_sum = dist_sum;
2061 }
2062 
av1_block_error_c(const tran_low_t * coeff,const tran_low_t * dqcoeff,intptr_t block_size,int64_t * ssz)2063 int64_t av1_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
2064                           intptr_t block_size, int64_t *ssz) {
2065   int i;
2066   int64_t error = 0, sqcoeff = 0;
2067 
2068   for (i = 0; i < block_size; i++) {
2069     const int diff = coeff[i] - dqcoeff[i];
2070     error += diff * diff;
2071     sqcoeff += coeff[i] * coeff[i];
2072   }
2073 
2074   *ssz = sqcoeff;
2075   return error;
2076 }
2077 
av1_highbd_block_error_c(const tran_low_t * coeff,const tran_low_t * dqcoeff,intptr_t block_size,int64_t * ssz,int bd)2078 int64_t av1_highbd_block_error_c(const tran_low_t *coeff,
2079                                  const tran_low_t *dqcoeff, intptr_t block_size,
2080                                  int64_t *ssz, int bd) {
2081   int i;
2082   int64_t error = 0, sqcoeff = 0;
2083   int shift = 2 * (bd - 8);
2084   int rounding = shift > 0 ? 1 << (shift - 1) : 0;
2085 
2086   for (i = 0; i < block_size; i++) {
2087     const int64_t diff = coeff[i] - dqcoeff[i];
2088     error += diff * diff;
2089     sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i];
2090   }
2091   assert(error >= 0 && sqcoeff >= 0);
2092   error = (error + rounding) >> shift;
2093   sqcoeff = (sqcoeff + rounding) >> shift;
2094 
2095   *ssz = sqcoeff;
2096   return error;
2097 }
2098 
2099 // Get transform block visible dimensions cropped to the MI units.
get_txb_dimensions(const MACROBLOCKD * xd,int plane,BLOCK_SIZE plane_bsize,int blk_row,int blk_col,BLOCK_SIZE tx_bsize,int * width,int * height,int * visible_width,int * visible_height)2100 static void get_txb_dimensions(const MACROBLOCKD *xd, int plane,
2101                                BLOCK_SIZE plane_bsize, int blk_row, int blk_col,
2102                                BLOCK_SIZE tx_bsize, int *width, int *height,
2103                                int *visible_width, int *visible_height) {
2104   assert(tx_bsize <= plane_bsize);
2105   int txb_height = block_size_high[tx_bsize];
2106   int txb_width = block_size_wide[tx_bsize];
2107   const int block_height = block_size_high[plane_bsize];
2108   const int block_width = block_size_wide[plane_bsize];
2109   const struct macroblockd_plane *const pd = &xd->plane[plane];
2110   // TODO(aconverse@google.com): Investigate using crop_width/height here rather
2111   // than the MI size
2112   const int block_rows =
2113       (xd->mb_to_bottom_edge >= 0)
2114           ? block_height
2115           : (xd->mb_to_bottom_edge >> (3 + pd->subsampling_y)) + block_height;
2116   const int block_cols =
2117       (xd->mb_to_right_edge >= 0)
2118           ? block_width
2119           : (xd->mb_to_right_edge >> (3 + pd->subsampling_x)) + block_width;
2120   const int tx_unit_size = tx_size_wide_log2[0];
2121   if (width) *width = txb_width;
2122   if (height) *height = txb_height;
2123   *visible_width = clamp(block_cols - (blk_col << tx_unit_size), 0, txb_width);
2124   *visible_height =
2125       clamp(block_rows - (blk_row << tx_unit_size), 0, txb_height);
2126 }
2127 
2128 // Compute the pixel domain distortion from src and dst on all visible 4x4s in
2129 // the
2130 // transform block.
pixel_dist(const AV1_COMP * const cpi,const MACROBLOCK * x,int plane,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize)2131 static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
2132                            int plane, const uint8_t *src, const int src_stride,
2133                            const uint8_t *dst, const int dst_stride,
2134                            int blk_row, int blk_col,
2135                            const BLOCK_SIZE plane_bsize,
2136                            const BLOCK_SIZE tx_bsize) {
2137   int txb_rows, txb_cols, visible_rows, visible_cols;
2138   const MACROBLOCKD *xd = &x->e_mbd;
2139 
2140   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
2141                      &txb_cols, &txb_rows, &visible_cols, &visible_rows);
2142   assert(visible_rows > 0);
2143   assert(visible_cols > 0);
2144 
2145 #if CONFIG_DIST_8X8
2146   if (x->using_dist_8x8 && plane == 0)
2147     return (unsigned)av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride,
2148                                   tx_bsize, txb_cols, txb_rows, visible_cols,
2149                                   visible_rows, x->qindex);
2150 #endif  // CONFIG_DIST_8X8
2151 
2152   unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
2153                                          dst_stride, tx_bsize, txb_rows,
2154                                          txb_cols, visible_rows, visible_cols);
2155 
2156   return sse;
2157 }
2158 
2159 // Compute the pixel domain distortion from diff on all visible 4x4s in the
2160 // transform block.
pixel_diff_dist(const MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8)2161 static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
2162                                       int blk_row, int blk_col,
2163                                       const BLOCK_SIZE plane_bsize,
2164                                       const BLOCK_SIZE tx_bsize,
2165                                       unsigned int *block_mse_q8) {
2166   int visible_rows, visible_cols;
2167   const MACROBLOCKD *xd = &x->e_mbd;
2168   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
2169                      NULL, &visible_cols, &visible_rows);
2170   const int diff_stride = block_size_wide[plane_bsize];
2171   const int16_t *diff = x->plane[plane].src_diff;
2172 #if CONFIG_DIST_8X8
2173   int txb_height = block_size_high[tx_bsize];
2174   int txb_width = block_size_wide[tx_bsize];
2175   if (x->using_dist_8x8 && plane == 0) {
2176     const int src_stride = x->plane[plane].src.stride;
2177     const int src_idx = (blk_row * src_stride + blk_col)
2178                         << tx_size_wide_log2[0];
2179     const int diff_idx = (blk_row * diff_stride + blk_col)
2180                          << tx_size_wide_log2[0];
2181     const uint8_t *src = &x->plane[plane].src.buf[src_idx];
2182     return dist_8x8_diff(x, src, src_stride, diff + diff_idx, diff_stride,
2183                          txb_width, txb_height, visible_cols, visible_rows,
2184                          x->qindex);
2185   }
2186 #endif
2187   diff += ((blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]);
2188   uint64_t sse =
2189       aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
2190   if (block_mse_q8 != NULL)
2191     *block_mse_q8 = (unsigned int)((256 * sse) / (visible_cols * visible_rows));
2192   return sse;
2193 }
2194 
av1_count_colors(const uint8_t * src,int stride,int rows,int cols,int * val_count)2195 int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
2196                      int *val_count) {
2197   const int max_pix_val = 1 << 8;
2198   memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
2199   for (int r = 0; r < rows; ++r) {
2200     for (int c = 0; c < cols; ++c) {
2201       const int this_val = src[r * stride + c];
2202       assert(this_val < max_pix_val);
2203       ++val_count[this_val];
2204     }
2205   }
2206   int n = 0;
2207   for (int i = 0; i < max_pix_val; ++i) {
2208     if (val_count[i]) ++n;
2209   }
2210   return n;
2211 }
2212 
av1_count_colors_highbd(const uint8_t * src8,int stride,int rows,int cols,int bit_depth,int * val_count)2213 int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
2214                             int bit_depth, int *val_count) {
2215   assert(bit_depth <= 12);
2216   const int max_pix_val = 1 << bit_depth;
2217   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
2218   memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
2219   for (int r = 0; r < rows; ++r) {
2220     for (int c = 0; c < cols; ++c) {
2221       const int this_val = src[r * stride + c];
2222       assert(this_val < max_pix_val);
2223       if (this_val >= max_pix_val) return 0;
2224       ++val_count[this_val];
2225     }
2226   }
2227   int n = 0;
2228   for (int i = 0; i < max_pix_val; ++i) {
2229     if (val_count[i]) ++n;
2230   }
2231   return n;
2232 }
2233 
inverse_transform_block_facade(MACROBLOCKD * xd,int plane,int block,int blk_row,int blk_col,int eob,int reduced_tx_set)2234 static void inverse_transform_block_facade(MACROBLOCKD *xd, int plane,
2235                                            int block, int blk_row, int blk_col,
2236                                            int eob, int reduced_tx_set) {
2237   struct macroblockd_plane *const pd = &xd->plane[plane];
2238   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
2239   const PLANE_TYPE plane_type = get_plane_type(plane);
2240   const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
2241   const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
2242                                           tx_size, reduced_tx_set);
2243   const int dst_stride = pd->dst.stride;
2244   uint8_t *dst =
2245       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
2246   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
2247                               dst_stride, eob, reduced_tx_set);
2248 }
2249 
// Forward declaration of a file-local (static) helper; its definition must
// appear later in this translation unit.
static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record, const uint32_t hash);
2251 
get_intra_txb_hash(MACROBLOCK * x,int plane,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size)2252 static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
2253                                    int blk_col, BLOCK_SIZE plane_bsize,
2254                                    TX_SIZE tx_size) {
2255   int16_t tmp_data[64 * 64];
2256   const int diff_stride = block_size_wide[plane_bsize];
2257   const int16_t *diff = x->plane[plane].src_diff;
2258   const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
2259   const int txb_w = tx_size_wide[tx_size];
2260   const int txb_h = tx_size_high[tx_size];
2261   uint8_t *hash_data = (uint8_t *)cur_diff_row;
2262   if (txb_w != diff_stride) {
2263     int16_t *cur_hash_row = tmp_data;
2264     for (int i = 0; i < txb_h; i++) {
2265       memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
2266       cur_hash_row += txb_w;
2267       cur_diff_row += diff_stride;
2268     }
2269     hash_data = (uint8_t *)tmp_data;
2270   }
2271   CRC32C *crc = &x->mb_rd_record.crc_calculator;
2272   const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
2273   return (hash << 5) + tx_size;
2274 }
2275 
dist_block_tx_domain(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int64_t * out_dist,int64_t * out_sse)2276 static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
2277                                         TX_SIZE tx_size, int64_t *out_dist,
2278                                         int64_t *out_sse) {
2279   MACROBLOCKD *const xd = &x->e_mbd;
2280   const struct macroblock_plane *const p = &x->plane[plane];
2281   const struct macroblockd_plane *const pd = &xd->plane[plane];
2282   // Transform domain distortion computation is more efficient as it does
2283   // not involve an inverse transform, but it is less accurate.
2284   const int buffer_length = av1_get_max_eob(tx_size);
2285   int64_t this_sse;
2286   // TX-domain results need to shift down to Q2/D10 to match pixel
2287   // domain distortion values which are in Q2^2
2288   int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
2289   tran_low_t *const coeff = BLOCK_OFFSET(p->coeff, block);
2290   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
2291 
2292   if (is_cur_buf_hbd(xd))
2293     *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
2294                                        xd->bd);
2295   else
2296     *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
2297 
2298   *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
2299   *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
2300 }
2301 
dist_block_px_domain(const AV1_COMP * cpi,MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,int block,int blk_row,int blk_col,TX_SIZE tx_size)2302 static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
2303                                            int plane, BLOCK_SIZE plane_bsize,
2304                                            int block, int blk_row, int blk_col,
2305                                            TX_SIZE tx_size) {
2306   MACROBLOCKD *const xd = &x->e_mbd;
2307   const struct macroblock_plane *const p = &x->plane[plane];
2308   const struct macroblockd_plane *const pd = &xd->plane[plane];
2309   const uint16_t eob = p->eobs[block];
2310   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
2311   const int bsw = block_size_wide[tx_bsize];
2312   const int bsh = block_size_high[tx_bsize];
2313   const int src_stride = x->plane[plane].src.stride;
2314   const int dst_stride = xd->plane[plane].dst.stride;
2315   // Scale the transform block index to pixel unit.
2316   const int src_idx = (blk_row * src_stride + blk_col) << tx_size_wide_log2[0];
2317   const int dst_idx = (blk_row * dst_stride + blk_col) << tx_size_wide_log2[0];
2318   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
2319   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
2320   const tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
2321 
2322   assert(cpi != NULL);
2323   assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
2324 
2325   uint8_t *recon;
2326   DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
2327 
2328   if (is_cur_buf_hbd(xd)) {
2329     recon = CONVERT_TO_BYTEPTR(recon16);
2330     av1_highbd_convolve_2d_copy_sr(CONVERT_TO_SHORTPTR(dst), dst_stride,
2331                                    CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw,
2332                                    bsh, NULL, NULL, 0, 0, NULL, xd->bd);
2333   } else {
2334     recon = (uint8_t *)recon16;
2335     av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
2336                             NULL, 0, 0, NULL);
2337   }
2338 
2339   const PLANE_TYPE plane_type = get_plane_type(plane);
2340   TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col, tx_size,
2341                                     cpi->common.reduced_tx_set_used);
2342   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
2343                               MAX_TX_SIZE, eob,
2344                               cpi->common.reduced_tx_set_used);
2345 
2346   return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
2347                          blk_row, blk_col, plane_bsize, tx_bsize);
2348 }
2349 
get_diff_mean(const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int w,int h)2350 static double get_diff_mean(const uint8_t *src, int src_stride,
2351                             const uint8_t *dst, int dst_stride, int w, int h) {
2352   double sum = 0.0;
2353   for (int j = 0; j < h; ++j) {
2354     for (int i = 0; i < w; ++i) {
2355       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
2356       sum += diff;
2357     }
2358   }
2359   assert(w > 0 && h > 0);
2360   return sum / (w * h);
2361 }
2362 
get_highbd_diff_mean(const uint8_t * src8,int src_stride,const uint8_t * dst8,int dst_stride,int w,int h)2363 static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
2364                                    const uint8_t *dst8, int dst_stride, int w,
2365                                    int h) {
2366   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
2367   const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
2368   double sum = 0.0;
2369   for (int j = 0; j < h; ++j) {
2370     for (int i = 0; i < w; ++i) {
2371       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
2372       sum += diff;
2373     }
2374   }
2375   assert(w > 0 && h > 0);
2376   return sum / (w * h);
2377 }
2378 
get_sse_norm(const int16_t * diff,int stride,int w,int h)2379 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
2380   double sum = 0.0;
2381   for (int j = 0; j < h; ++j) {
2382     for (int i = 0; i < w; ++i) {
2383       const int err = diff[j * stride + i];
2384       sum += err * err;
2385     }
2386   }
2387   assert(w > 0 && h > 0);
2388   return sum / (w * h);
2389 }
2390 
get_sad_norm(const int16_t * diff,int stride,int w,int h)2391 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
2392   double sum = 0.0;
2393   for (int j = 0; j < h; ++j) {
2394     for (int i = 0; i < w; ++i) {
2395       sum += abs(diff[j * stride + i]);
2396     }
2397   }
2398   assert(w > 0 && h > 0);
2399   return sum / (w * h);
2400 }
2401 
get_2x2_normalized_sses_and_sads(const AV1_COMP * const cpi,BLOCK_SIZE tx_bsize,const uint8_t * const src,int src_stride,const uint8_t * const dst,int dst_stride,const int16_t * const src_diff,int diff_stride,double * const sse_norm_arr,double * const sad_norm_arr)2402 static void get_2x2_normalized_sses_and_sads(
2403     const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
2404     int src_stride, const uint8_t *const dst, int dst_stride,
2405     const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
2406     double *const sad_norm_arr) {
2407   const BLOCK_SIZE tx_bsize_half =
2408       get_partition_subsize(tx_bsize, PARTITION_SPLIT);
2409   if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
2410     const int half_width = block_size_wide[tx_bsize] / 2;
2411     const int half_height = block_size_high[tx_bsize] / 2;
2412     for (int row = 0; row < 2; ++row) {
2413       for (int col = 0; col < 2; ++col) {
2414         const int16_t *const this_src_diff =
2415             src_diff + row * half_height * diff_stride + col * half_width;
2416         if (sse_norm_arr) {
2417           sse_norm_arr[row * 2 + col] =
2418               get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
2419         }
2420         if (sad_norm_arr) {
2421           sad_norm_arr[row * 2 + col] =
2422               get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
2423         }
2424       }
2425     }
2426   } else {  // use function pointers to calculate stats
2427     const int half_width = block_size_wide[tx_bsize_half];
2428     const int half_height = block_size_high[tx_bsize_half];
2429     const int num_samples_half = half_width * half_height;
2430     for (int row = 0; row < 2; ++row) {
2431       for (int col = 0; col < 2; ++col) {
2432         const uint8_t *const this_src =
2433             src + row * half_height * src_stride + col * half_width;
2434         const uint8_t *const this_dst =
2435             dst + row * half_height * dst_stride + col * half_width;
2436 
2437         if (sse_norm_arr) {
2438           unsigned int this_sse;
2439           cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
2440                                         dst_stride, &this_sse);
2441           sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
2442         }
2443 
2444         if (sad_norm_arr) {
2445           const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
2446               this_src, src_stride, this_dst, dst_stride);
2447           sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
2448         }
2449       }
2450     }
2451   }
2452 }
2453 
2454 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
2455 // 0: Do not collect any RD stats
2456 // 1: Collect RD stats for transform units
2457 // 2: Collect RD stats for partition units
2458 #if CONFIG_COLLECT_RD_STATS
2459 
2460 #if CONFIG_COLLECT_RD_STATS == 1
get_mean(const int16_t * diff,int stride,int w,int h)2461 static double get_mean(const int16_t *diff, int stride, int w, int h) {
2462   double sum = 0.0;
2463   for (int j = 0; j < h; ++j) {
2464     for (int i = 0; i < w; ++i) {
2465       sum += diff[j * stride + i];
2466     }
2467   }
2468   assert(w > 0 && h > 0);
2469   return sum / (w * h);
2470 }
2471 
PrintTransformUnitStats(const AV1_COMP * const cpi,MACROBLOCK * x,const RD_STATS * const rd_stats,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,TX_TYPE tx_type,int64_t rd)2472 static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x,
2473                                     const RD_STATS *const rd_stats, int blk_row,
2474                                     int blk_col, BLOCK_SIZE plane_bsize,
2475                                     TX_SIZE tx_size, TX_TYPE tx_type,
2476                                     int64_t rd) {
2477   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
2478 
2479   // Generate small sample to restrict output size.
2480   static unsigned int seed = 21743;
2481   if (lcg_rand16(&seed) % 256 > 0) return;
2482 
2483   const char output_file[] = "tu_stats.txt";
2484   FILE *fout = fopen(output_file, "a");
2485   if (!fout) return;
2486 
2487   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
2488   const MACROBLOCKD *const xd = &x->e_mbd;
2489   const int plane = 0;
2490   struct macroblock_plane *const p = &x->plane[plane];
2491   const struct macroblockd_plane *const pd = &xd->plane[plane];
2492   const int txw = tx_size_wide[tx_size];
2493   const int txh = tx_size_high[tx_size];
2494   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2495   const int q_step = pd->dequant_Q3[1] >> dequant_shift;
2496   const int num_samples = txw * txh;
2497 
2498   const double rate_norm = (double)rd_stats->rate / num_samples;
2499   const double dist_norm = (double)rd_stats->dist / num_samples;
2500 
2501   fprintf(fout, "%g %g", rate_norm, dist_norm);
2502 
2503   const int src_stride = p->src.stride;
2504   const uint8_t *const src =
2505       &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
2506   const int dst_stride = pd->dst.stride;
2507   const uint8_t *const dst =
2508       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
2509   unsigned int sse;
2510   cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
2511   const double sse_norm = (double)sse / num_samples;
2512 
2513   const unsigned int sad =
2514       cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
2515   const double sad_norm = (double)sad / num_samples;
2516 
2517   fprintf(fout, " %g %g", sse_norm, sad_norm);
2518 
2519   const int diff_stride = block_size_wide[plane_bsize];
2520   const int16_t *const src_diff =
2521       &p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
2522 
2523   double sse_norm_arr[4], sad_norm_arr[4];
2524   get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
2525                                    dst_stride, src_diff, diff_stride,
2526                                    sse_norm_arr, sad_norm_arr);
2527   for (int i = 0; i < 4; ++i) {
2528     fprintf(fout, " %g", sse_norm_arr[i]);
2529   }
2530   for (int i = 0; i < 4; ++i) {
2531     fprintf(fout, " %g", sad_norm_arr[i]);
2532   }
2533 
2534   const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
2535   const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
2536 
2537   fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
2538           tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
2539 
2540   int model_rate;
2541   int64_t model_dist;
2542   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
2543                                    &model_rate, &model_dist);
2544   const double model_rate_norm = (double)model_rate / num_samples;
2545   const double model_dist_norm = (double)model_dist / num_samples;
2546   fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
2547 
2548   const double mean = get_mean(src_diff, diff_stride, txw, txh);
2549   float hor_corr, vert_corr;
2550   av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
2551                                   &vert_corr);
2552   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
2553 
2554   double hdist[4] = { 0 }, vdist[4] = { 0 };
2555   get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
2556                                1, hdist, vdist);
2557   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
2558           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
2559 
2560   fprintf(fout, " %d %" PRId64, x->rdmult, rd);
2561 
2562   fprintf(fout, "\n");
2563   fclose(fout);
2564 }
2565 #endif  // CONFIG_COLLECT_RD_STATS == 1
2566 
2567 #if CONFIG_COLLECT_RD_STATS >= 2
PrintPredictionUnitStats(const AV1_COMP * const cpi,const TileDataEnc * tile_data,MACROBLOCK * x,const RD_STATS * const rd_stats,BLOCK_SIZE plane_bsize)2568 static void PrintPredictionUnitStats(const AV1_COMP *const cpi,
2569                                      const TileDataEnc *tile_data,
2570                                      MACROBLOCK *x,
2571                                      const RD_STATS *const rd_stats,
2572                                      BLOCK_SIZE plane_bsize) {
2573   if (rd_stats->invalid_rate) return;
2574   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
2575 
2576   if (cpi->sf.inter_mode_rd_model_estimation == 1 &&
2577       (tile_data == NULL ||
2578        !tile_data->inter_mode_rd_models[plane_bsize].ready))
2579     return;
2580   (void)tile_data;
2581   // Generate small sample to restrict output size.
2582   static unsigned int seed = 95014;
2583 
2584   if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
2585       1)
2586     return;
2587 
2588   const char output_file[] = "pu_stats.txt";
2589   FILE *fout = fopen(output_file, "a");
2590   if (!fout) return;
2591 
2592   const MACROBLOCKD *const xd = &x->e_mbd;
2593   const int plane = 0;
2594   struct macroblock_plane *const p = &x->plane[plane];
2595   const struct macroblockd_plane *const pd = &xd->plane[plane];
2596   const int diff_stride = block_size_wide[plane_bsize];
2597   int bw, bh;
2598   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
2599                      &bh);
2600   const int num_samples = bw * bh;
2601   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2602   const int q_step = pd->dequant_Q3[1] >> dequant_shift;
2603 
2604   const double rate_norm = (double)rd_stats->rate / num_samples;
2605   const double dist_norm = (double)rd_stats->dist / num_samples;
2606   const double rdcost_norm =
2607       (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
2608 
2609   fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
2610 
2611   const int src_stride = p->src.stride;
2612   const uint8_t *const src = p->src.buf;
2613   const int dst_stride = pd->dst.stride;
2614   const uint8_t *const dst = pd->dst.buf;
2615   const int16_t *const src_diff = p->src_diff;
2616   const int shift = (xd->bd - 8);
2617 
2618   int64_t sse;
2619   if (is_cur_buf_hbd(xd)) {
2620     sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
2621                          bw, bh);
2622   } else {
2623     sse =
2624         aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw, bh);
2625   }
2626   sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2627   const double sse_norm = (double)sse / num_samples;
2628 
2629   const unsigned int sad =
2630       cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
2631   const double sad_norm =
2632       (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
2633 
2634   fprintf(fout, " %g %g", sse_norm, sad_norm);
2635 
2636   double sse_norm_arr[4], sad_norm_arr[4];
2637   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
2638                                    dst_stride, src_diff, diff_stride,
2639                                    sse_norm_arr, sad_norm_arr);
2640   if (shift) {
2641     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
2642     for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
2643   }
2644   for (int i = 0; i < 4; ++i) {
2645     fprintf(fout, " %g", sse_norm_arr[i]);
2646   }
2647   for (int i = 0; i < 4; ++i) {
2648     fprintf(fout, " %g", sad_norm_arr[i]);
2649   }
2650 
2651   fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
2652 
2653   int model_rate;
2654   int64_t model_dist;
2655   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
2656                                    &model_rate, &model_dist);
2657   const double model_rdcost_norm =
2658       (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
2659   const double model_rate_norm = (double)model_rate / num_samples;
2660   const double model_dist_norm = (double)model_dist / num_samples;
2661   fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
2662           model_rdcost_norm);
2663 
2664   double mean;
2665   if (is_cur_buf_hbd(xd)) {
2666     mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
2667                                 pd->dst.stride, bw, bh);
2668   } else {
2669     mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
2670                          bw, bh);
2671   }
2672   mean /= (1 << shift);
2673   float hor_corr, vert_corr;
2674   av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
2675                                   &vert_corr);
2676   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
2677 
2678   double hdist[4] = { 0 }, vdist[4] = { 0 };
2679   get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
2680                                dst_stride, 1, hdist, vdist);
2681   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
2682           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
2683 
2684   if (cpi->sf.inter_mode_rd_model_estimation == 1) {
2685     assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
2686     const int64_t overall_sse = get_sse(cpi, x);
2687     int est_residue_cost = 0;
2688     int64_t est_dist = 0;
2689     get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
2690                       &est_dist);
2691     const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
2692     const double est_dist_norm = (double)est_dist / num_samples;
2693     const double est_rdcost_norm =
2694         (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
2695     fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
2696             est_rdcost_norm);
2697   }
2698 
2699   fprintf(fout, "\n");
2700   fclose(fout);
2701 }
2702 #endif  // CONFIG_COLLECT_RD_STATS >= 2
2703 #endif  // CONFIG_COLLECT_RD_STATS
2704 
model_rd_with_dnn(const AV1_COMP * const cpi,const MACROBLOCK * const x,BLOCK_SIZE plane_bsize,int plane,int64_t sse,int num_samples,int * rate,int64_t * dist)2705 static void model_rd_with_dnn(const AV1_COMP *const cpi,
2706                               const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
2707                               int plane, int64_t sse, int num_samples,
2708                               int *rate, int64_t *dist) {
2709   const MACROBLOCKD *const xd = &x->e_mbd;
2710   const struct macroblockd_plane *const pd = &xd->plane[plane];
2711   const int log_numpels = num_pels_log2_lookup[plane_bsize];
2712 
2713   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2714   const int q_step = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
2715 
2716   const struct macroblock_plane *const p = &x->plane[plane];
2717   int bw, bh;
2718   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
2719                      &bh);
2720   const int src_stride = p->src.stride;
2721   const uint8_t *const src = p->src.buf;
2722   const int dst_stride = pd->dst.stride;
2723   const uint8_t *const dst = pd->dst.buf;
2724   const int16_t *const src_diff = p->src_diff;
2725   const int diff_stride = block_size_wide[plane_bsize];
2726   const int shift = (xd->bd - 8);
2727 
2728   if (sse == 0) {
2729     if (rate) *rate = 0;
2730     if (dist) *dist = 0;
2731     return;
2732   }
2733   if (plane) {
2734     int model_rate;
2735     int64_t model_dist;
2736     model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, num_samples,
2737                           &model_rate, &model_dist);
2738     if (rate) *rate = model_rate;
2739     if (dist) *dist = model_dist;
2740     return;
2741   }
2742 
2743   aom_clear_system_state();
2744   const double sse_norm = (double)sse / num_samples;
2745 
2746   double sse_norm_arr[4];
2747   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
2748                                    dst_stride, src_diff, diff_stride,
2749                                    sse_norm_arr, NULL);
2750   double mean;
2751   if (is_cur_buf_hbd(xd)) {
2752     mean = get_highbd_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
2753   } else {
2754     mean = get_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
2755   }
2756   if (shift) {
2757     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
2758     mean /= (1 << shift);
2759   }
2760   double sse_norm_sum = 0.0, sse_frac_arr[3];
2761   for (int k = 0; k < 4; ++k) sse_norm_sum += sse_norm_arr[k];
2762   for (int k = 0; k < 3; ++k)
2763     sse_frac_arr[k] =
2764         sse_norm_sum > 0.0 ? sse_norm_arr[k] / sse_norm_sum : 0.25;
2765   const double q_sqr = (double)(q_step * q_step);
2766   const double q_sqr_by_sse_norm = q_sqr / (sse_norm + 1.0);
2767   const double mean_sqr_by_sse_norm = mean * mean / (sse_norm + 1.0);
2768   float hor_corr, vert_corr;
2769   av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
2770                                   &vert_corr);
2771 
2772   float features[NUM_FEATURES_PUSTATS];
2773   features[0] = (float)hor_corr;
2774   features[1] = (float)log_numpels;
2775   features[2] = (float)mean_sqr_by_sse_norm;
2776   features[3] = (float)q_sqr_by_sse_norm;
2777   features[4] = (float)sse_frac_arr[0];
2778   features[5] = (float)sse_frac_arr[1];
2779   features[6] = (float)sse_frac_arr[2];
2780   features[7] = (float)vert_corr;
2781 
2782   float rate_f, dist_by_sse_norm_f;
2783   av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_by_sse_norm_f);
2784   av1_nn_predict(features, &av1_pustats_rate_nnconfig, &rate_f);
2785   aom_clear_system_state();
2786   const float dist_f = (float)((double)dist_by_sse_norm_f * (1.0 + sse_norm));
2787   int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
2788   int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5);
2789 
2790   // Check if skip is better
2791   if (rate_i == 0) {
2792     dist_i = sse << 4;
2793   } else if (RDCOST(x->rdmult, rate_i, dist_i) >=
2794              RDCOST(x->rdmult, 0, sse << 4)) {
2795     rate_i = 0;
2796     dist_i = sse << 4;
2797   }
2798 
2799   if (rate) *rate = rate_i;
2800   if (dist) *dist = dist_i;
2801   return;
2802 }
2803 
model_rd_for_sb_with_dnn(const AV1_COMP * const cpi,BLOCK_SIZE bsize,MACROBLOCK * x,MACROBLOCKD * xd,int plane_from,int plane_to,int mi_row,int mi_col,int * out_rate_sum,int64_t * out_dist_sum,int * skip_txfm_sb,int64_t * skip_sse_sb,int * plane_rate,int64_t * plane_sse,int64_t * plane_dist)2804 static void model_rd_for_sb_with_dnn(
2805     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
2806     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
2807     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
2808     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
2809   (void)mi_row;
2810   (void)mi_col;
2811   // Note our transform coeffs are 8 times an orthogonal transform.
2812   // Hence quantizer step is also 8 times. To get effective quantizer
2813   // we need to divide by 8 before sending to modeling function.
2814   const int ref = xd->mi[0]->ref_frame[0];
2815 
2816   int64_t rate_sum = 0;
2817   int64_t dist_sum = 0;
2818   int64_t total_sse = 0;
2819 
2820   for (int plane = plane_from; plane <= plane_to; ++plane) {
2821     struct macroblockd_plane *const pd = &xd->plane[plane];
2822     const BLOCK_SIZE plane_bsize =
2823         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2824     int64_t dist, sse;
2825     int rate;
2826 
2827     if (x->skip_chroma_rd && plane) continue;
2828 
2829     const struct macroblock_plane *const p = &x->plane[plane];
2830     const int shift = (xd->bd - 8);
2831     int bw, bh;
2832     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
2833                        &bw, &bh);
2834     if (is_cur_buf_hbd(xd)) {
2835       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2836                            pd->dst.stride, bw, bh);
2837     } else {
2838       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2839                     bh);
2840     }
2841     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2842 
2843     model_rd_with_dnn(cpi, x, plane_bsize, plane, sse, bw * bh, &rate, &dist);
2844 
2845     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2846 
2847     total_sse += sse;
2848     rate_sum += rate;
2849     dist_sum += dist;
2850 
2851     if (plane_rate) plane_rate[plane] = rate;
2852     if (plane_sse) plane_sse[plane] = sse;
2853     if (plane_dist) plane_dist[plane] = dist;
2854   }
2855 
2856   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2857   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2858   *out_rate_sum = (int)rate_sum;
2859   *out_dist_sum = dist_sum;
2860 }
2861 
// Models rate and distortion with a fitted surface over two features:
// log2(sse_norm + 1) and log2(sse_norm / qstep^2).
model_rd_with_surffit(const AV1_COMP * const cpi,const MACROBLOCK * const x,BLOCK_SIZE plane_bsize,int plane,int64_t sse,int num_samples,int * rate,int64_t * dist)2864 static void model_rd_with_surffit(const AV1_COMP *const cpi,
2865                                   const MACROBLOCK *const x,
2866                                   BLOCK_SIZE plane_bsize, int plane,
2867                                   int64_t sse, int num_samples, int *rate,
2868                                   int64_t *dist) {
2869   (void)cpi;
2870   (void)plane_bsize;
2871   const MACROBLOCKD *const xd = &x->e_mbd;
2872   const struct macroblockd_plane *const pd = &xd->plane[plane];
2873   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2874   const int qstep = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
2875   if (sse == 0) {
2876     if (rate) *rate = 0;
2877     if (dist) *dist = 0;
2878     return;
2879   }
2880   aom_clear_system_state();
2881   const double sse_norm = (double)sse / num_samples;
2882   const double qstepsqr = (double)qstep * qstep;
2883   const double xm = log(sse_norm + 1.0) / log(2.0);
2884   const double yl = log(sse_norm / qstepsqr) / log(2.0);
2885   double rate_f, dist_by_sse_norm_f;
2886 
2887   av1_model_rd_surffit(plane_bsize, sse_norm, xm, yl, &rate_f,
2888                        &dist_by_sse_norm_f);
2889 
2890   const double dist_f = dist_by_sse_norm_f * sse_norm;
2891   int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
2892   int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5);
2893   aom_clear_system_state();
2894 
2895   // Check if skip is better
2896   if (rate_i == 0) {
2897     dist_i = sse << 4;
2898   } else if (RDCOST(x->rdmult, rate_i, dist_i) >=
2899              RDCOST(x->rdmult, 0, sse << 4)) {
2900     rate_i = 0;
2901     dist_i = sse << 4;
2902   }
2903 
2904   if (rate) *rate = rate_i;
2905   if (dist) *dist = dist_i;
2906 }
2907 
// Estimates rate and distortion for planes [plane_from, plane_to] of the
// current prediction using the surface-fit model.
//
// For each evaluated plane, the prediction sse (normalized to an 8-bit
// scale) is fed to model_rd_with_surffit(). Chroma planes are skipped when
// x->skip_chroma_rd is set. The luma sse is also cached in
// x->pred_sse[ref_frame[0]].
//
// Outputs: *out_rate_sum / *out_dist_sum receive the totals. The optional
// per-plane arrays and the skip outputs are written only when non-NULL:
// *skip_txfm_sb is 1 iff the total sse is 0, and *skip_sse_sb is total sse << 4.
static void model_rd_for_sb_with_surffit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
  (void)mi_row;
  (void)mi_col;
  const int ref = xd->mi[0]->ref_frame[0];
  const int is_hbd = is_cur_buf_hbd(xd);
  // High-bitdepth sse is scaled down by 2 * (bd - 8) bits to 8-bit range.
  const int bd_shift = 2 * (xd->bd - 8);

  int64_t rate_total = 0;
  int64_t dist_total = 0;
  int64_t sse_total = 0;

  for (int plane = plane_from; plane <= plane_to; ++plane) {
    if (plane != 0 && x->skip_chroma_rd) continue;

    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const struct macroblock_plane *const p = &x->plane[plane];
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);

    int bw, bh;
    get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                       &bw, &bh);

    const int64_t raw_sse =
        is_hbd ? aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh)
               : aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
    const int64_t sse = ROUND_POWER_OF_TWO(raw_sse, bd_shift);

    int rate;
    int64_t dist;
    model_rd_with_surffit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
                          &dist);

    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);

    sse_total += sse;
    rate_total += rate;
    dist_total += dist;

    if (plane_rate != NULL) plane_rate[plane] = rate;
    if (plane_sse != NULL) plane_sse[plane] = sse;
    if (plane_dist != NULL) plane_dist[plane] = dist;
  }

  if (skip_txfm_sb != NULL) *skip_txfm_sb = (sse_total == 0);
  if (skip_sse_sb != NULL) *skip_sse_sb = sse_total << 4;
  *out_rate_sum = (int)rate_total;
  *out_dist_sum = dist_total;
}
2966 
2967 // Fits a curve for rate and distortion using as feature:
2968 // log2(sse_norm/qstep^2)
// Outputs (each optional, written only when non-NULL):
//   *rate - modelled rate for the plane block, 0 when skipping is preferred.
//   *dist - modelled distortion, or sse << 4 when skipping is preferred.
// num_samples is the number of pixels sse was accumulated over.
static void model_rd_with_curvfit(const AV1_COMP *const cpi,
                                  const MACROBLOCK *const x,
                                  BLOCK_SIZE plane_bsize, int plane,
                                  int64_t sse, int num_samples, int *rate,
                                  int64_t *dist) {
  (void)cpi;

  // A zero residual needs no bits and has no distortion.
  if (sse == 0) {
    if (rate != NULL) *rate = 0;
    if (dist != NULL) *dist = 0;
    return;
  }

  // dequant_Q3 carries a factor of 8 (plus bd - 8 extra bits in high
  // bitdepth); shift it out to obtain the effective quantizer step.
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int dequant_shift = is_cur_buf_hbd(xd) ? xd->bd - 5 : 3;
  const int qstep = AOMMAX(xd->plane[plane].dequant_Q3[1] >> dequant_shift, 1);

  aom_clear_system_state();
  const double sse_norm = (double)sse / num_samples;
  const double qstep_sq = (double)qstep * qstep;
  // Single curve feature: per-sample sse relative to qstep^2, in log2 domain.
  const double feat = log2(sse_norm / qstep_sq);

  double rate_per_sample, dist_ratio;
  av1_model_rd_curvfit(plane_bsize, sse_norm, feat, &rate_per_sample,
                       &dist_ratio);

  // The model is per sample; scale back up and round to the nearest integer,
  // clamping negative predictions to zero.
  const double dist_norm = dist_ratio * sse_norm;
  int rate_out = (int)(AOMMAX(0.0, rate_per_sample * num_samples) + 0.5);
  int64_t dist_out = (int64_t)(AOMMAX(0.0, dist_norm * num_samples) + 0.5);
  aom_clear_system_state();

  // Report "skip" (no rate, full sse as distortion) if the model predicts no
  // bits, or if coding is not strictly cheaper than skipping in RD terms.
  const int64_t skip_dist = sse << 4;
  if (rate_out == 0 || RDCOST(x->rdmult, rate_out, dist_out) >=
                           RDCOST(x->rdmult, 0, skip_dist)) {
    rate_out = 0;
    dist_out = skip_dist;
  }

  if (rate != NULL) *rate = rate_out;
  if (dist != NULL) *dist = dist_out;
}
3012 
// Estimates rate and distortion for planes [plane_from, plane_to] of the
// current prediction using the curve-fit model.
//
// For each evaluated plane, the prediction sse (normalized to an 8-bit
// scale) is fed to model_rd_with_curvfit(). Chroma planes are skipped when
// x->skip_chroma_rd is set. The luma sse is also cached in
// x->pred_sse[ref_frame[0]].
//
// Outputs: *out_rate_sum / *out_dist_sum receive the totals. The optional
// per-plane arrays and the skip outputs are written only when non-NULL:
// *skip_txfm_sb is 1 iff the total sse is 0, and *skip_sse_sb is total sse << 4.
static void model_rd_for_sb_with_curvfit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
  (void)mi_row;
  (void)mi_col;
  const int ref = xd->mi[0]->ref_frame[0];
  const int is_hbd = is_cur_buf_hbd(xd);
  // High-bitdepth sse is scaled down by 2 * (bd - 8) bits to 8-bit range.
  const int bd_shift = 2 * (xd->bd - 8);

  int64_t rate_total = 0;
  int64_t dist_total = 0;
  int64_t sse_total = 0;

  for (int plane = plane_from; plane <= plane_to; ++plane) {
    if (plane != 0 && x->skip_chroma_rd) continue;

    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const struct macroblock_plane *const p = &x->plane[plane];
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);

    int bw, bh;
    get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                       &bw, &bh);

    const int64_t raw_sse =
        is_hbd ? aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh)
               : aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
    const int64_t sse = ROUND_POWER_OF_TWO(raw_sse, bd_shift);

    int rate;
    int64_t dist;
    model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
                          &dist);

    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);

    sse_total += sse;
    rate_total += rate;
    dist_total += dist;

    if (plane_rate != NULL) plane_rate[plane] = rate;
    if (plane_sse != NULL) plane_sse[plane] = sse;
    if (plane_dist != NULL) plane_dist[plane] = dist;
  }

  if (skip_txfm_sb != NULL) *skip_txfm_sb = (sse_total == 0);
  if (skip_sse_sb != NULL) *skip_sse_sb = sse_total << 4;
  *out_rate_sum = (int)rate_total;
  *out_dist_sum = dist_total;
}
3072 
// Searches the allowed transform types for one transform block and keeps the
// one with the lowest RD cost.
//
// On return:
//   - *best_rd_stats holds rate/dist/sse/skip of the chosen tx_type.
//   - x->plane[plane].eobs[block] and txb_entropy_ctx[block] hold that
//     type's values.
//   - For plane 0, mbmi->txk_type is updated via update_txk_array().
//   - For intra blocks with a non-zero eob that are not the bottom-right
//     transform block of plane_bsize, the block is re-quantized if needed
//     and inverse transformed (inverse_transform_block_facade), so the next
//     transform block can use it for prediction.
//
// For square luma intra blocks in intra-only frames (with use_intra_txb_hash
// and the block inside the tile), results are cached in / reused from
// x->txb_rd_record_intra, keyed by a hash of the residual plus the entropy
// context.
//
// Parameters of note:
//   ftxs_mode     - FTXS_DCT_AND_1D_DCT_ONLY limits a multi-type search to
//                   DCT_DCT, V_DCT and H_DCT.
//   skip_trellis  - if set (or forced by the segment's optimize mode), skips
//                   av1_optimize_b().
//   ref_best_rd   - RD budget used for the optimize_b precheck and for
//                   adaptive_txb_search_level early termination.
// Returns the best RD cost found.
static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx,
                               FAST_TX_SEARCH_MODE ftxs_mode,
                               int use_fast_coef_costing, int skip_trellis,
                               int64_t ref_best_rd, RD_STATS *best_rd_stats) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  struct macroblockd_plane *const pd = &xd->plane[plane];
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  int64_t best_rd = INT64_MAX;
  uint16_t best_eob = 0;
  TX_TYPE best_tx_type = DCT_DCT;
  // Only assigned together with best_tx_type inside the search loop; it stays
  // TX_TYPES when the result comes from the intra hash.
  TX_TYPE last_tx_type = TX_TYPES;
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // Scratch buffer that is swapped with pd->dqcoeff so the dqcoeff of the
  // best tx_type survives while other types are tried. pd->dqcoeff is
  // restored to orig_dqcoeff before returning.
  DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
  tran_low_t *orig_dqcoeff = pd->dqcoeff;
  tran_low_t *best_dqcoeff = this_dqcoeff;
  const int txk_type_idx =
      av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
  int perform_block_coeff_opt;
  av1_invalid_rd_stats(best_rd_stats);

  TXB_RD_INFO *intra_txb_rd_info = NULL;
  uint16_t cur_joint_ctx = 0;
  // Block position in mi units, derived from the distance to the top/left
  // edge (mb_to_*_edge appear to be in 1/8-pel units, hence the extra 3).
  const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
  const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
  // Whether the block lies inside the current tile (strictly before the
  // bottom/right tile edges).
  const int within_border =
      mi_row >= xd->tile.mi_row_start &&
      (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
      mi_col >= xd->tile.mi_col_start &&
      (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
  skip_trellis |=
      cpi->optimize_seg_arr[mbmi->segment_id] == NO_TRELLIS_OPT ||
      cpi->optimize_seg_arr[mbmi->segment_id] == FINAL_PASS_TRELLIS_OPT;
  // Intra-hash fast path: reuse a cached result when the residual hash, the
  // entropy context and the tx_type derived for this block all match.
  if (within_border && cpi->sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
      !is_inter && plane == 0 &&
      tx_size_wide[tx_size] == tx_size_high[tx_size]) {
    const uint32_t intra_hash =
        get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
    const int intra_hash_idx =
        find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
    intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];

    cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
    if (intra_txb_rd_info->entropy_context == cur_joint_ctx &&
        x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
      mbmi->txk_type[txk_type_idx] = intra_txb_rd_info->tx_type;
      const TX_TYPE ref_tx_type =
          av1_get_tx_type(get_plane_type(plane), &x->e_mbd, blk_row, blk_col,
                          tx_size, cpi->common.reduced_tx_set_used);
      if (ref_tx_type == intra_txb_rd_info->tx_type) {
        best_rd_stats->rate = intra_txb_rd_info->rate;
        best_rd_stats->dist = intra_txb_rd_info->dist;
        best_rd_stats->sse = intra_txb_rd_info->sse;
        best_rd_stats->skip = intra_txb_rd_info->eob == 0;
        x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
        x->plane[plane].txb_entropy_ctx[block] =
            intra_txb_rd_info->txb_entropy_ctx;
        best_rd = RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->dist);
        best_eob = intra_txb_rd_info->eob;
        best_tx_type = intra_txb_rd_info->tx_type;
        update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                         best_tx_type);
        // Jumps over the declarations below. This is valid C because none of
        // them is a VLA; only rate_cost is touched after the label, and only
        // as an output argument.
        goto RECON_INTRA;
      }
    }
  }

  // Candidate tx_type range [txk_start, txk_end], narrowed by speed features.
  int rate_cost = 0;
  TX_TYPE txk_start = DCT_DCT;
  TX_TYPE txk_end = TX_TYPES - 1;
  if ((!is_inter && x->use_default_intra_tx_type) ||
      (is_inter && x->use_default_inter_tx_type)) {
    txk_start = txk_end =
        get_default_tx_type(0, xd, tx_size, cpi->is_screen_content_type);
  } else if (x->rd_model == LOW_TXFM_RD || x->cb_partition_scan) {
    if (plane == 0) txk_end = DCT_DCT;
  }

  uint8_t best_txb_ctx = 0;
  const TxSetType tx_set_type =
      av1_get_ext_tx_set_type(tx_size, is_inter, cm->reduced_tx_set_used);

  TX_TYPE uv_tx_type = DCT_DCT;
  if (plane) {
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
    uv_tx_type = txk_start = txk_end =
        av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size,
                        cm->reduced_tx_set_used);
  }
  const uint16_t ext_tx_used_flag = av1_ext_tx_used_flag[tx_set_type];
  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
      ext_tx_used_flag == 0x0001 ||
      (is_inter && cpi->oxcf.use_inter_dct_only) ||
      (!is_inter && cpi->oxcf.use_intra_dct_only)) {
    txk_start = txk_end = DCT_DCT;
  }
  // Bit t of allowed_tx_mask is set when tx_type t may be evaluated.
  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
  if (txk_start == txk_end) {
    allowed_tx_mask = 1 << txk_start;
    allowed_tx_mask &= ext_tx_used_flag;
  } else if (fast_tx_search) {
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
    allowed_tx_mask &= ext_tx_used_flag;
  } else {
    assert(plane == 0);
    allowed_tx_mask = ext_tx_used_flag;
    // !fast_tx_search && txk_end != txk_start && plane == 0
    const int do_prune = cpi->sf.tx_type_search.prune_mode > NO_PRUNE;
    if (do_prune && is_inter) {
      if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) {
        const uint16_t prune =
            prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
                        cpi->sf.tx_type_search.prune_mode);
        allowed_tx_mask &= (~prune);
      } else {
        allowed_tx_mask &= (~x->tx_search_prune[tx_set_type]);
      }
    }
  }

  // Clears every type from FLIPADST_DCT through H_FLIPADST when flip/identity
  // transforms are disabled.
  if (cpi->oxcf.enable_flip_idtx == 0) {
    for (TX_TYPE tx_type = FLIPADST_DCT; tx_type <= H_FLIPADST; ++tx_type) {
      allowed_tx_mask &= ~(1 << tx_type);
    }
  }

  // Need to have at least one transform type allowed.
  if (allowed_tx_mask == 0) {
    txk_start = txk_end = (plane ? uv_tx_type : DCT_DCT);
    allowed_tx_mask = (1 << txk_start);
  }

  // Pixel-domain sse of the residual, normalized to 8-bit range and scaled by
  // 16 so it can be used interchangeably with the dist/sse values that
  // dist_block_tx_domain() produces below.
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  int64_t block_sse = 0;
  unsigned int block_mse_q8 = UINT_MAX;
  block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize,
                              &block_mse_q8);
  assert(block_mse_q8 != UINT_MAX);
  if (is_cur_buf_hbd(xd)) {
    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
  }
  block_sse *= 16;
  // Transform domain distortion is accurate for higher residuals.
  // TODO(any): Experiment with variance and mean based thresholds
  int use_transform_domain_distortion =
      (cpi->sf.use_transform_domain_distortion > 0) &&
      (block_mse_q8 >= cpi->tx_domain_dist_threshold) &&
      // Any 64-pt transforms only preserves half the coefficients.
      // Therefore transform domain distortion is not valid for these
      // transform sizes.
      txsize_sqr_up_map[tx_size] != TX_64X64;
#if CONFIG_DIST_8X8
  if (x->using_dist_8x8) use_transform_domain_distortion = 0;
#endif
  // With sf.use_transform_domain_distortion == 1, the search ranks types by
  // transform-domain distortion and the winner's distortion is recomputed in
  // the pixel domain after the loop.
  int calc_pixel_domain_distortion_final =
      cpi->sf.use_transform_domain_distortion == 1 &&
      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD &&
      !x->cb_partition_scan;
  // With only one candidate there is nothing to rank, so pixel-domain
  // distortion is computed directly inside the loop instead.
  if (calc_pixel_domain_distortion_final &&
      (txk_start == txk_end || allowed_tx_mask == 0x0001))
    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;

  const uint16_t *eobs_ptr = x->plane[plane].eobs;

  // Decide from the residual MSE whether coefficient R-D optimization
  // (av1_optimize_b) is worth running: it tends to help for small residuals
  // and be less effective for large ones.
  // TODO(any): Experiment with variance and mean based thresholds
  perform_block_coeff_opt = (block_mse_q8 <= cpi->coeff_opt_dist_threshold);

  for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
    if (!(allowed_tx_mask & (1 << tx_type))) continue;
    if (plane == 0) mbmi->txk_type[txk_type_idx] = tx_type;
    RD_STATS this_rd_stats;
    av1_invalid_rd_stats(&this_rd_stats);
    // Quantize (optionally with trellis optimization) and obtain the rate.
    if (skip_trellis || (!perform_block_coeff_opt)) {
      av1_xform_quant(
          cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, tx_type,
          USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP);
      rate_cost = av1_cost_coeffs(cm, x, plane, block, tx_size, tx_type,
                                  txb_ctx, use_fast_coef_costing);
    } else {
      av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize,
                      tx_size, tx_type, AV1_XFORM_QUANT_FP);
      if (cpi->sf.optimize_b_precheck && best_rd < INT64_MAX &&
          eobs_ptr[block] >= 4) {
        // Calculate distortion quickly in transform domain.
        dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                             &this_rd_stats.sse);

        // Skip the costly trellis pass when 7/8 of the distortion cost alone
        // already exceeds the best RD so far (or the caller's budget).
        const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd);
        const int64_t dist_cost_estimate =
            RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse));
        if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue;
      }
      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
                     cpi->sf.trellis_eob_fast, &rate_cost);
    }
    // Measure distortion for this tx_type.
    if (eobs_ptr[block] == 0) {
      // When eob is 0, pixel domain distortion is more efficient and accurate.
      this_rd_stats.dist = this_rd_stats.sse = block_sse;
    } else if (use_transform_domain_distortion) {
      dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                           &this_rd_stats.sse);
    } else {
      int64_t sse_diff = INT64_MAX;
      // high_energy threshold assumes that every pixel within a txfm block
      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
      // for 8 bit, then the threshold is scaled based on input bit depth.
      const int64_t high_energy_thresh =
          ((int64_t)128 * 128 * tx_size_2d[tx_size]) << ((xd->bd - 8) * 2);
      const int is_high_energy = (block_sse >= high_energy_thresh);
      if (tx_size == TX_64X64 || is_high_energy) {
        // Because 3 out of 4 quadrants of transform coefficients are forced to
        // zero, the inverse transform has a tendency to overflow. sse_diff
        // is effectively the energy of those 3 quadrants, here we use it
        // to decide if we should do pixel domain distortion. If the energy
        // is mostly in first quadrant, then it is unlikely that we have
        // overflow issue in inverse transform.
        dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                             &this_rd_stats.sse);
        sse_diff = block_sse - this_rd_stats.sse;
      }
      if (tx_size != TX_64X64 || !is_high_energy ||
          (sse_diff * 2) < this_rd_stats.sse) {
        const int64_t tx_domain_dist = this_rd_stats.dist;
        this_rd_stats.dist = dist_block_px_domain(
            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
        // For high energy blocks, occasionally, the pixel domain distortion
        // can be artificially low due to clamping at reconstruction stage
        // even when inverse transform output is hugely different from the
        // actual residue.
        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
          this_rd_stats.dist = tx_domain_dist;
      } else {
        this_rd_stats.dist += sse_diff;
      }
      this_rd_stats.sse = block_sse;
    }

    this_rd_stats.rate = rate_cost;

    const int64_t rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);

    if (rd < best_rd) {
      best_rd = rd;
      *best_rd_stats = this_rd_stats;
      best_tx_type = tx_type;
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
      best_eob = x->plane[plane].eobs[block];
      last_tx_type = best_tx_type;

      // Keep this type's dqcoeff by swapping it out of pd->dqcoeff; the next
      // candidate is then quantized into the previous scratch buffer.
      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
      best_dqcoeff = pd->dqcoeff;
      pd->dqcoeff = tmp_dqcoeff;
    }

#if CONFIG_COLLECT_RD_STATS == 1
    if (plane == 0) {
      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
                              plane_bsize, tx_size, tx_type, rd);
    }
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if COLLECT_TX_SIZE_DATA
    // Generate small sample to restrict output size.
    static unsigned int seed = 21743;
    if (lcg_rand16(&seed) % 200 == 0) {
      FILE *fp = NULL;

      if (within_border) {
        fp = fopen(av1_tx_size_data_output_file, "a");
      }

      if (fp) {
        // Transform info and RD
        const int txb_w = tx_size_wide[tx_size];
        const int txb_h = tx_size_high[tx_size];

        // Residue signal.
        const int diff_stride = block_size_wide[plane_bsize];
        struct macroblock_plane *const p = &x->plane[plane];
        const int16_t *src_diff =
            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];

        for (int r = 0; r < txb_h; ++r) {
          for (int c = 0; c < txb_w; ++c) {
            fprintf(fp, "%d,", src_diff[c]);
          }
          src_diff += diff_stride;
        }

        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
        fprintf(fp, "\n");
        fclose(fp);
      }
    }
#endif  // COLLECT_TX_SIZE_DATA

    // Stop once the best RD, discounted by 2^-level, already exceeds the
    // caller's budget.
    if (cpi->sf.adaptive_txb_search_level) {
      if ((best_rd - (best_rd >> cpi->sf.adaptive_txb_search_level)) >
          ref_best_rd) {
        break;
      }
    }

    // Skip transform type search when we found the block has been quantized to
    // all zero and at the same time, it has better rdcost than doing transform.
    if (cpi->sf.tx_type_search.skip_tx_search && !best_eob) break;
  }

  assert(best_rd != INT64_MAX);

  // Restore the per-block state of the winning tx_type.
  best_rd_stats->skip = best_eob == 0;
  if (plane == 0) {
    update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                     best_tx_type);
  }
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
  x->plane[plane].eobs[block] = best_eob;

  pd->dqcoeff = best_dqcoeff;

  // Replace the transform-domain estimate with pixel-domain distortion for
  // the winner (see calc_pixel_domain_distortion_final above).
  if (calc_pixel_domain_distortion_final && best_eob) {
    best_rd_stats->dist = dist_block_px_domain(
        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
    best_rd_stats->sse = block_sse;
  }

  // Store the result in the intra hash record for later reuse.
  if (intra_txb_rd_info != NULL) {
    intra_txb_rd_info->valid = 1;
    intra_txb_rd_info->entropy_context = cur_joint_ctx;
    intra_txb_rd_info->rate = best_rd_stats->rate;
    intra_txb_rd_info->dist = best_rd_stats->dist;
    intra_txb_rd_info->sse = best_rd_stats->sse;
    intra_txb_rd_info->eob = best_eob;
    intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
    if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
  }

RECON_INTRA:
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // intra mode needs decoded result such that the next transform block
    // can use it for prediction.
    // if the last search tx_type is the best tx_type, we don't need to
    // do this again
    // Since last_tx_type only changes together with best_tx_type, this
    // re-quantization runs only on the intra-hash path (last_tx_type is still
    // TX_TYPES there).
    if (best_tx_type != last_tx_type) {
      if (skip_trellis) {
        av1_xform_quant(
            cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
            best_tx_type,
            USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP);
      } else {
        av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize,
                        tx_size, best_tx_type, AV1_XFORM_QUANT_FP);
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       cpi->sf.trellis_eob_fast, &rate_cost);
      }
    }

    inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->reduced_tx_set_used);

    // This may happen because of hash collision. The eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                       DCT_DCT);
    }
  }
  pd->dqcoeff = orig_dqcoeff;

  return best_rd;
}
3461 
// Per-transform-block callback used by txfm_rd_in_plane().
//
// Runs the tx_type search for one transform block (for intra blocks, after
// building the intra prediction and residual), updates the entropy contexts
// and skip flags, and accumulates the block's RD stats into args. The block
// is charged the cheaper of its coded RD and its zero-residual RD. Once
// args->this_rd exceeds args->best_rd, the remaining callbacks only mark the
// search as incomplete.
static void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
                          BLOCK_SIZE plane_bsize, TX_SIZE tx_size, void *arg) {
  struct rdcost_block_args *const args = arg;
  if (args->exit_early) {
    args->incomplete_exit = 1;
    return;
  }

  MACROBLOCK *const x = args->x;
  MACROBLOCKD *const xd = &x->e_mbd;
  const AV1_COMP *const cpi = args->cpi;
  const AV1_COMMON *const cm = &cpi->common;
  const int is_inter = is_inter_block(xd->mi[0]);
  ENTROPY_CONTEXT *const above_ctx = args->t_above + blk_col;
  ENTROPY_CONTEXT *const left_ctx = args->t_left + blk_row;

  RD_STATS blk_stats;
  av1_init_rd_stats(&blk_stats);

  // Intra prediction and residual are formed per transform block.
  if (!is_inter) {
    av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
    av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
  }

  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, plane, above_ctx, left_ctx, &txb_ctx);
  search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                  &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
                  args->skip_trellis, args->best_rd - args->this_rd,
                  &blk_stats);

  if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
    assert(!is_inter || plane_bsize < BLOCK_8X8);
    cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
  }

#if CONFIG_RD_DEBUG
  av1_update_txb_coeff_cost(&blk_stats, plane, tx_size, blk_row, blk_col,
                            blk_stats.rate);
#endif  // CONFIG_RD_DEBUG
  av1_set_txb_context(x, plane, block, tx_size, above_ctx, left_ctx);

  const int eob = x->plane[plane].eobs[block];
  const int blk_idx =
      blk_row * (block_size_wide[plane_bsize] >> tx_size_wide_log2[0]) +
      blk_col;
  // Only luma records whether the block quantized to zero; chroma entries
  // are always cleared.
  set_blk_skip(x, plane, blk_idx, plane == 0 && eob == 0);

  const int64_t coded_rd = RDCOST(x->rdmult, blk_stats.rate, blk_stats.dist);
  const int64_t zeroed_rd = RDCOST(x->rdmult, 0, blk_stats.sse);

  blk_stats.skip &= !eob;
  av1_merge_rd_stats(&args->rd_stats, &blk_stats);

  args->this_rd += AOMMIN(coded_rd, zeroed_rd);
  if (args->this_rd > args->best_rd) args->exit_early = 1;
}
3526 
// Computes RD stats for one plane coded entirely with tx_size, visiting every
// transform block through block_rd_txfm().
//
// this_rd is the cost already accumulated by the caller; ref_best_rd is the
// budget. rd_stats is set to invalid when 64-point transforms are disabled
// and tx_size needs one. It is also set to invalid if the search was cut
// short: for inter blocks when some transform block was never evaluated, for
// intra blocks as soon as the budget was exceeded. Otherwise it receives the
// accumulated stats.
static void txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
                             RD_STATS *rd_stats, int64_t ref_best_rd,
                             int64_t this_rd, int plane, BLOCK_SIZE bsize,
                             TX_SIZE tx_size, int use_fast_coef_costing,
                             FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
  if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[tx_size] == TX_64X64) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  MACROBLOCKD *const xd = &x->e_mbd;
  if (plane == 0) xd->mi[0]->tx_size = tx_size;

  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  args.this_rd = this_rd;
  args.use_fast_coef_costing = use_fast_coef_costing;
  args.ftxs_mode = ftxs_mode;
  args.skip_trellis = skip_trellis;
  av1_init_rd_stats(&args.rd_stats);
  av1_get_entropy_contexts(bsize, &xd->plane[plane], args.t_above,
                           args.t_left);
  // If the starting cost is already over budget, every block callback bails
  // out immediately.
  args.exit_early = this_rd > ref_best_rd;

  av1_foreach_transformed_block_in_plane(xd, bsize, plane, block_rd_txfm,
                                         &args);

  const int is_inter = is_inter_block(xd->mi[0]);
  if (is_inter ? args.incomplete_exit : args.exit_early) {
    av1_invalid_rd_stats(rd_stats);
  } else {
    *rd_stats = args.rd_stats;
  }
}
3571 
// Returns the cost of signalling tx_size for an intra block of size bsize.
// The cost is 0 when the frame does not use TX_MODE_SELECT or when bsize
// carries no tx size syntax.
static int tx_size_cost(const AV1_COMMON *const cm, const MACROBLOCK *const x,
                        BLOCK_SIZE bsize, TX_SIZE tx_size) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  assert(bsize == xd->mi[0]->sb_type);

  if (cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(bsize)) {
    const int32_t size_cat = bsize_to_tx_size_cat(bsize);
    const int ctx = get_tx_size_context(xd);
    const int depth = tx_size_to_depth(tx_size, bsize);
    return x->tx_size_cost[size_cat][ctx][depth];
  }
  return 0;
}
3583 
// Luma RD cost of coding the current block with one uniform transform size
// |tx_size|.
//
// rd_stats receives the stats from txfm_rd_in_plane(). Its rate never
// includes the skip-flag cost, because callers add that once all planes are
// evaluated. It does include the transform-size signalling cost, except when
// an inter block is reported as skipped. The returned RD, however, does
// account for the skip / non-skip header.
//
// Inter blocks (in non-lossless segments) are additionally compared against
// being forced to skip. Returns INT64_MAX when txfm_rd_in_plane() produced
// invalid stats.
static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                        RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
                        TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                        int skip_trellis) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  const int tx_select =
      cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type);

  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));

  // Transform-size signalling: inter blocks pay for the "no split" partition
  // symbol, intra blocks for the explicit tx size symbol.
  int size_rate;
  if (is_inter) {
    const int part_ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
    size_rate = x->txfm_partition_cost[part_ctx][0];
  } else {
    size_rate = tx_size_cost(cm, x, bs, tx_size);
  }
  const int size_hdr_rate = size_rate * tx_select;

  const int skip_ctx = av1_get_skip_context(xd);
  const int no_skip_cost = x->skip_cost[skip_ctx][0];
  const int skip_cost = x->skip_cost[skip_ctx][1];

  // Header-only costs used to seed early termination. A skipped inter block
  // does not code its transform size.
  const int skip_hdr_rate = is_inter ? skip_cost : skip_cost + size_hdr_rate;
  const int64_t skip_hdr_rd = RDCOST(x->rdmult, skip_hdr_rate, 0);
  const int64_t coded_hdr_rd =
      RDCOST(x->rdmult, no_skip_cost + size_hdr_rate, 0);

  mbmi->tx_size = tx_size;
  txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
                   AOMMIN(coded_hdr_rd, skip_hdr_rd), AOM_PLANE_Y, bs, tx_size,
                   cpi->sf.use_fast_coef_costing, ftxs_mode, skip_trellis);
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd = INT64_MAX;
  if (!rd_stats->skip) {
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_cost + size_hdr_rate,
                rd_stats->dist);
    rd_stats->rate += size_hdr_rate;
  } else if (is_inter) {
    rd = RDCOST(x->rdmult, skip_cost, rd_stats->sse);
  } else {
    rd = RDCOST(x->rdmult, skip_cost + size_hdr_rate, rd_stats->sse);
    rd_stats->rate += size_hdr_rate;
  }

  // For lossy inter blocks, dropping the residual entirely may be cheaper.
  if (is_inter && !xd->lossless[mbmi->segment_id]) {
    const int64_t forced_skip_rd = RDCOST(x->rdmult, skip_cost, rd_stats->sse);
    if (forced_skip_rd <= rd) {
      rd = forced_skip_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip = 1;
    }
  }
  return rd;
}
3652 
// Quick luma RD estimate. It recomputes the luma residual, then evaluates
// only the largest rectangular transform for |bs| under the LOW_TXFM_RD model.
// Trellis is skipped when the segment is configured with
// NO_ESTIMATE_YRD_TRELLIS_OPT.
//
// On success the skip-flag cost is folded into rd_stats->rate: a skipped
// block's rate becomes just the skip cost, and otherwise the non-skip cost is
// added. Returns INT64_MAX (rd_stats untouched here) when no estimate exists.
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
                                   MACROBLOCK *x, int64_t ref_best_rd,
                                   RD_STATS *rd_stats) {
  MACROBLOCKD *const xd = &x->e_mbd;
  av1_subtract_plane(x, bs, 0);

  const int skip_trellis = cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
                           NO_ESTIMATE_YRD_TRELLIS_OPT;
  x->rd_model = LOW_TXFM_RD;
  const int64_t rd =
      txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs, max_txsize_rect_lookup[bs],
               FTXS_NONE, skip_trellis);
  x->rd_model = FULL_TXFM_RD;

  if (rd == INT64_MAX) return rd;

  const int skip_ctx = av1_get_skip_context(xd);
  if (rd_stats->skip) {
    rd_stats->rate = x->skip_cost[skip_ctx][1];
  } else {
    rd_stats->rate += x->skip_cost[skip_ctx][0];
  }
  return rd;
}
3677 
// Luma RD with no transform-size search: the size is the one implied by
// cm->tx_mode for |bs|. Transform-type pruning is prepared for that size's
// extended tx set first. The pruning flags are cleared before returning.
static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
                                   RD_STATS *rd_stats, int64_t ref_best_rd,
                                   BLOCK_SIZE bs) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);

  mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
  prune_tx(cpi, bs, x, xd,
           av1_get_ext_tx_set_type(mbmi->tx_size, is_inter,
                                   cm->reduced_tx_set_used));

  // Seed early termination with the cheaper of the two skip-flag headers.
  const int skip_ctx = av1_get_skip_context(xd);
  const int64_t coded_hdr_rd = RDCOST(x->rdmult, x->skip_cost[skip_ctx][0], 0);
  const int64_t skip_hdr_rd = RDCOST(x->rdmult, x->skip_cost[skip_ctx][1], 0);

  txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
                   AOMMIN(coded_hdr_rd, skip_hdr_rd), AOM_PLANE_Y, bs,
                   mbmi->tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE, 0);

  // Clear the transform pruning state.
  av1_zero(x->tx_search_prune);
  x->tx_split_prune_flag = 0;
}
3705 
// Luma RD with every transform block forced to TX_4X4. super_block_yrd()
// takes this path for lossless segments. No header cost is used to seed
// early termination (this_rd = 0).
static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    RD_STATS *rd_stats, int64_t ref_best_rd,
                                    BLOCK_SIZE bs) {
  MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
  mbmi->tx_size = TX_4X4;
  // TODO(any) : Pass this_rd based on skip/non-skip cost
  txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, /*this_rd=*/0, AOM_PLANE_Y,
                   bs, mbmi->tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE,
                   /*skip_trellis=*/0);
}
3717 
bsize_to_num_blk(BLOCK_SIZE bsize)3718 static INLINE int bsize_to_num_blk(BLOCK_SIZE bsize) {
3719   int num_blk = 1 << (num_pels_log2_lookup[bsize] - 2 * tx_size_wide_log2[0]);
3720   return num_blk;
3721 }
3722 
get_search_init_depth(int mi_width,int mi_height,int is_inter,const SPEED_FEATURES * sf)3723 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
3724                                  const SPEED_FEATURES *sf) {
3725   if (sf->tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
3726 
3727   if (sf->tx_size_search_lgr_block) {
3728     if (mi_width > mi_size_wide[BLOCK_64X64] ||
3729         mi_height > mi_size_high[BLOCK_64X64])
3730       return MAX_VARTX_DEPTH;
3731   }
3732 
3733   if (is_inter) {
3734     return (mi_height != mi_width) ? sf->inter_tx_size_search_init_depth_rect
3735                                    : sf->inter_tx_size_search_init_depth_sqr;
3736   } else {
3737     return (mi_height != mi_width) ? sf->intra_tx_size_search_init_depth_rect
3738                                    : sf->intra_tx_size_search_init_depth_sqr;
3739   }
3740 }
3741 
// Searches uniform luma transform sizes for the current block and keeps the
// one with the lowest txfm_yrd() cost.
//
// With TX_MODE_SELECT the search starts at max_txsize_rect_lookup[bs] and at
// the depth from get_search_init_depth(). Otherwise only the size implied by
// cm->tx_mode is tried (init_depth = MAX_TX_DEPTH gives a single iteration).
// Each further depth moves to sub_tx_size_map[] of the previous size.
//
// On return, if any size gave a finite RD, rd_stats holds the winning stats,
// and mbmi->tx_size, mbmi->txk_type and x->blk_skip are restored to the
// winner. Otherwise rd_stats stays invalid (as set by av1_invalid_rd_stats),
// and mbmi->tx_size is whatever the last txfm_yrd() call left. The transform
// pruning flags are always cleared.
static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
                                        MACROBLOCK *x, RD_STATS *rd_stats,
                                        int64_t ref_best_rd, BLOCK_SIZE bs) {
  av1_invalid_rd_stats(rd_stats);

  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
  int start_tx;
  int depth, init_depth;

  if (tx_select) {
    start_tx = max_rect_tx_size;
    init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
                                       is_inter_block(mbmi), &cpi->sf);
  } else {
    // Fixed tx mode: evaluate exactly one size.
    const TX_SIZE chosen_tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
    start_tx = chosen_tx_size;
    init_depth = MAX_TX_DEPTH;
  }

  prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16);

  // Snapshot of the best candidate's per-block decisions; txfm_yrd()
  // overwrites mbmi->txk_type and x->blk_skip on every call.
  TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  TX_SIZE best_tx_size = max_rect_tx_size;
  int64_t best_rd = INT64_MAX;
  const int n4 = bsize_to_num_blk(bs);
  x->rd_model = FULL_TXFM_RD;
  depth = init_depth;
  // Per-depth RD; sizes skipped via `continue` keep INT64_MAX.
  int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
  // Note: `continue` still executes the increment, so depth and n advance
  // together even for skipped sizes.
  for (int n = start_tx; depth <= MAX_TX_DEPTH;
       depth++, n = sub_tx_size_map[n]) {
#if CONFIG_DIST_8X8
    if (x->using_dist_8x8) {
      if (tx_size_wide[n] < 8 || tx_size_high[n] < 8) continue;
    }
#endif
    if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[n] == TX_64X64) continue;

    RD_STATS this_rd_stats;
    rd[depth] =
        txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, n, FTXS_NONE, 0);

    if (rd[depth] < best_rd) {
      memcpy(best_txk_type, mbmi->txk_type,
             sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
      memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * n4);
      best_tx_size = n;
      best_rd = rd[depth];
      *rd_stats = this_rd_stats;
    }
    // No smaller size exists below TX_4X4.
    if (n == TX_4X4) break;
    // If we are searching three depths, prune the smallest size depending
    // on rd results for the first two depths for low contrast blocks
    // (source_variance < 256): stop when the second depth was worse than the
    // first.
    if (depth > init_depth && depth != MAX_TX_DEPTH &&
        x->source_variance < 256) {
      if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
    }
  }

  // Restore the winner's decisions only if some size produced valid stats.
  if (rd_stats->rate != INT_MAX) {
    mbmi->tx_size = best_tx_size;
    memcpy(mbmi->txk_type, best_txk_type,
           sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
    memcpy(x->blk_skip, best_blk_skip, sizeof(best_blk_skip[0]) * n4);
  }

  // Reset the pruning flags.
  av1_zero(x->tx_search_prune);
  x->tx_split_prune_flag = 0;
}
3816 
// Per-block-size coefficient thresholds used by predict_skip_flag(), stored as
// origin_threshold * 128 / 100.
// predict_skip_flag() refuses to predict skip when
// (|coef| << 7) >= threshold * q, i.e. when |coef| / q >= origin_threshold /
// 100. In other words origin_threshold acts as a percentage of the quantizer
// step (dc_q for the DC coefficient, ac_q for the others).
// Row index: 0 for 8-bit, 1 for 10-bit, 2 for any other bit depth.
// Column index: BLOCK_SIZE.
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
3832 
// Transform size used by predict_skip_flag() for each BLOCK_SIZE. It is the
// largest rectangular transform for the block, capped so that neither side
// exceeds 16. This table is a precomputed form of:
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};
3843 
3844 // Uses simple features on top of DCT coefficients to quickly predict
3845 // whether optimal RD decision is to skip encoding the residual.
3846 // The sse value is stored in dist.
predict_skip_flag(MACROBLOCK * x,BLOCK_SIZE bsize,int64_t * dist,int reduced_tx_set)3847 static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
3848                              int reduced_tx_set) {
3849   const int bw = block_size_wide[bsize];
3850   const int bh = block_size_high[bsize];
3851   const MACROBLOCKD *xd = &x->e_mbd;
3852   const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
3853 
3854   *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
3855 
3856   const int64_t mse = *dist / bw / bh;
3857   // Normalized quantizer takes the transform upscaling factor (8 for tx size
3858   // smaller than 32) into account.
3859   const int16_t normalized_dc_q = dc_q >> 3;
3860   const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
3861   // Predict not to skip when mse is larger than threshold.
3862   if (mse > mse_thresh) return 0;
3863 
3864   const int max_tx_size = max_predict_sf_tx_size[bsize];
3865   const int tx_h = tx_size_high[max_tx_size];
3866   const int tx_w = tx_size_wide[max_tx_size];
3867   DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
3868   TxfmParam param;
3869   param.tx_type = DCT_DCT;
3870   param.tx_size = max_tx_size;
3871   param.bd = xd->bd;
3872   param.is_hbd = is_cur_buf_hbd(xd);
3873   param.lossless = 0;
3874   param.tx_set_type = av1_get_ext_tx_set_type(
3875       param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
3876   const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
3877   const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
3878   const int16_t *src_diff = x->plane[0].src_diff;
3879   const int n_coeff = tx_w * tx_h;
3880   const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
3881   const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
3882   const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
3883   for (int row = 0; row < bh; row += tx_h) {
3884     for (int col = 0; col < bw; col += tx_w) {
3885       av1_fwd_txfm(src_diff + col, coefs, bw, &param);
3886       // Operating on TX domain, not pixels; we want the QTX quantizers
3887       const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
3888       if (dc_coef >= dc_thresh) return 0;
3889       for (int i = 1; i < n_coeff; ++i) {
3890         const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
3891         if (ac_coef >= ac_thresh) return 0;
3892       }
3893     }
3894     src_diff += tx_h * bw;
3895   }
3896   return 1;
3897 }
3898 
3899 // Used to set proper context for early termination with skip = 1.
set_skip_flag(MACROBLOCK * x,RD_STATS * rd_stats,int bsize,int64_t dist)3900 static void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats, int bsize,
3901                           int64_t dist) {
3902   MACROBLOCKD *const xd = &x->e_mbd;
3903   MB_MODE_INFO *const mbmi = xd->mi[0];
3904   const int n4 = bsize_to_num_blk(bsize);
3905   const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
3906   memset(mbmi->txk_type, DCT_DCT, sizeof(mbmi->txk_type[0]) * TXK_TYPE_BUF_LEN);
3907   memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
3908   mbmi->tx_size = tx_size;
3909   for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
3910   rd_stats->skip = 1;
3911   if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
3912   rd_stats->dist = rd_stats->sse = (dist << 4);
3913   // Though decision is to make the block as skip based on luma stats,
3914   // it is possible that block becomes non skip after chroma rd. In addition
3915   // intermediate non skip costs calculated by caller function will be
3916   // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
3917   // accounted). Hence intermediate rate is populated to code the luma tx blks
3918   // as skip, the caller function based on final rd decision (i.e., skip vs
3919   // non-skip) sets the final rate accordingly. Here the rate populated
3920   // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
3921   // size possible) in the current block. Eg: For 128*128 block, rate would be
3922   // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
3923   // block as 'all zeros'
3924   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3925   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3926   av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
3927   ENTROPY_CONTEXT *ta = ctxa;
3928   ENTROPY_CONTEXT *tl = ctxl;
3929   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3930   TXB_CTX txb_ctx;
3931   get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
3932   const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
3933                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3934   rd_stats->rate = zero_blk_rate *
3935                    (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
3936                    (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
3937 }
3938 
get_block_residue_hash(MACROBLOCK * x,BLOCK_SIZE bsize)3939 static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
3940   const int rows = block_size_high[bsize];
3941   const int cols = block_size_wide[bsize];
3942   const int16_t *diff = x->plane[0].src_diff;
3943   const uint32_t hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
3944                                              (uint8_t *)diff, 2 * rows * cols);
3945   return (hash << 5) + bsize;
3946 }
3947 
// Records the block's current luma transform decisions under |hash| in the
// ring buffer |tx_rd_record|. The saved fields are tx_size, the first |n4|
// blk_skip flags, inter_tx_size, txk_type and rd_stats. Once the buffer holds
// RD_RECORD_BUFFER_LEN entries, the oldest one is overwritten.
static void save_tx_rd_info(int n4, uint32_t hash, const MACROBLOCK *const x,
                            const RD_STATS *const rd_stats,
                            MB_RD_RECORD *tx_rd_record) {
  int slot;
  if (tx_rd_record->num >= RD_RECORD_BUFFER_LEN) {
    // Full: reuse the oldest slot and advance the start of the ring.
    slot = tx_rd_record->index_start;
    tx_rd_record->index_start = (slot + 1) % RD_RECORD_BUFFER_LEN;
  } else {
    slot =
        (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
    tx_rd_record->num++;
  }

  const MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
  MB_RD_INFO *const entry = &tx_rd_record->tx_rd_info[slot];
  entry->hash_value = hash;
  entry->tx_size = mbmi->tx_size;
  memcpy(entry->blk_skip, x->blk_skip, sizeof(entry->blk_skip[0]) * n4);
  av1_copy(entry->inter_tx_size, mbmi->inter_tx_size);
  av1_copy(entry->txk_type, mbmi->txk_type);
  entry->rd_stats = *rd_stats;
}
3972 
// Restores a previously saved luma transform decision. Counterpart of
// save_tx_rd_info(): it copies tx_size, inter_tx_size, txk_type and the first
// |n4| blk_skip flags back into the block, and the saved stats into
// |rd_stats|.
static void fetch_tx_rd_info(int n4, const MB_RD_INFO *const tx_rd_info,
                             RD_STATS *const rd_stats, MACROBLOCK *const x) {
  MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
  const size_t skip_bytes = sizeof(tx_rd_info->blk_skip[0]) * n4;
  mbmi->tx_size = tx_rd_info->tx_size;
  memcpy(x->blk_skip, tx_rd_info->blk_skip, skip_bytes);
  av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
  av1_copy(mbmi->txk_type, tx_rd_info->txk_type);
  *rd_stats = tx_rd_info->rd_stats;
}
3984 
find_mb_rd_info(const MB_RD_RECORD * const mb_rd_record,const int64_t ref_best_rd,const uint32_t hash)3985 static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
3986                                       const int64_t ref_best_rd,
3987                                       const uint32_t hash) {
3988   int32_t match_index = -1;
3989   if (ref_best_rd != INT64_MAX) {
3990     for (int i = 0; i < mb_rd_record->num; ++i) {
3991       const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
3992       // If there is a match in the tx_rd_record, fetch the RD decision and
3993       // terminate early.
3994       if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
3995         match_index = index;
3996         break;
3997       }
3998     }
3999   }
4000   return match_index;
4001 }
4002 
// Luma RD search for the current block. Paths are tried in this order:
//   1. For inter blocks fully inside the tile with use_mb_rd_hash, reuse a
//      cached decision whose residual hash matches.
//   2. If skip-flag prediction (inter, non-lossless) says skip, set up a
//      skipped block via set_skip_flag().
//   3. Otherwise search: TX_4X4 only for lossless segments, the tx_mode size
//      for USE_LARGESTALL, or a full size/type search.
// Results from paths 2 and 3 are cached when hashing is enabled. The
// transform pruning flags are cleared on every path (path 3 leaves that to
// the helper it calls).
static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                            RD_STATS *rd_stats, BLOCK_SIZE bs,
                            int64_t ref_best_rd) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  av1_init_rd_stats(rd_stats);
  const int is_inter = is_inter_block(mbmi);
  assert(bs == mbmi->sb_type);
  const int is_lossless = xd->lossless[mbmi->segment_id];
  const int n4 = bsize_to_num_blk(bs);

  const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
  const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
  const int inside_tile = mi_row >= xd->tile.mi_row_start &&
                          (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
                          mi_col >= xd->tile.mi_col_start &&
                          (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
  const int is_mb_rd_hash_enabled =
      inside_tile && cpi->sf.use_mb_rd_hash && is_inter;

  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  if (is_mb_rd_hash_enabled) {
    hash = get_block_residue_hash(x, bs);
    mb_rd_record = &x->mb_rd_record;
    const int32_t hit = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
    if (hit != -1) {
      fetch_tx_rd_info(n4, &mb_rd_record->tx_rd_info[hit], rd_stats, x);
      av1_zero(x->tx_search_prune);
      x->tx_split_prune_flag = 0;
      return;
    }
  }

  // SSE of the residual, filled in by predict_skip_flag().
  int64_t dist = 0;
  const int predicted_skip =
      cpi->sf.tx_type_search.use_skip_flag_prediction && is_inter &&
      !is_lossless &&
      predict_skip_flag(x, bs, &dist, cpi->common.reduced_tx_set_used);

  if (predicted_skip) {
    set_skip_flag(x, rd_stats, bs, dist);
  } else if (is_lossless) {
    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else if (cpi->sf.tx_size_search_method == USE_LARGESTALL) {
    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else {
    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
  }

  if (is_mb_rd_hash_enabled) {
    assert(mb_rd_record != NULL);
    save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
  }

  if (predicted_skip) {
    av1_zero(x->tx_search_prune);
    x->tx_split_prune_flag = 0;
  }
}
4070 
4071 // Return the rate cost for luma prediction mode info. of intra blocks.
intra_mode_info_cost_y(const AV1_COMP * cpi,const MACROBLOCK * x,const MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int mode_cost)4072 static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x,
4073                                   const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
4074                                   int mode_cost) {
4075   int total_rate = mode_cost;
4076   const int use_palette = mbmi->palette_mode_info.palette_size[0] > 0;
4077   const int use_filter_intra = mbmi->filter_intra_mode_info.use_filter_intra;
4078   const int use_intrabc = mbmi->use_intrabc;
4079   // Can only activate one mode.
4080   assert(((mbmi->mode != DC_PRED) + use_palette + use_intrabc +
4081           use_filter_intra) <= 1);
4082   const int try_palette =
4083       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
4084   if (try_palette && mbmi->mode == DC_PRED) {
4085     const MACROBLOCKD *xd = &x->e_mbd;
4086     const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
4087     const int mode_ctx = av1_get_palette_mode_ctx(xd);
4088     total_rate += x->palette_y_mode_cost[bsize_ctx][mode_ctx][use_palette];
4089     if (use_palette) {
4090       const uint8_t *const color_map = xd->plane[0].color_index_map;
4091       int block_width, block_height, rows, cols;
4092       av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
4093                                &cols);
4094       const int plt_size = mbmi->palette_mode_info.palette_size[0];
4095       int palette_mode_cost =
4096           x->palette_y_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
4097           write_uniform_cost(plt_size, color_map[0]);
4098       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
4099       const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
4100       palette_mode_cost +=
4101           av1_palette_color_cost_y(&mbmi->palette_mode_info, color_cache,
4102                                    n_cache, cpi->common.seq_params.bit_depth);
4103       palette_mode_cost +=
4104           av1_cost_color_map(x, 0, bsize, mbmi->tx_size, PALETTE_MAP);
4105       total_rate += palette_mode_cost;
4106     }
4107   }
4108   if (av1_filter_intra_allowed(&cpi->common, mbmi)) {
4109     total_rate += x->filter_intra_cost[mbmi->sb_type][use_filter_intra];
4110     if (use_filter_intra) {
4111       total_rate += x->filter_intra_mode_cost[mbmi->filter_intra_mode_info
4112                                                   .filter_intra_mode];
4113     }
4114   }
4115   if (av1_is_directional_mode(mbmi->mode)) {
4116     if (av1_use_angle_delta(bsize)) {
4117       total_rate += x->angle_delta_cost[mbmi->mode - V_PRED]
4118                                        [MAX_ANGLE_DELTA +
4119                                         mbmi->angle_delta[PLANE_TYPE_Y]];
4120     }
4121   }
4122   if (av1_allow_intrabc(&cpi->common))
4123     total_rate += x->intrabc_cost[use_intrabc];
4124   return total_rate;
4125 }
4126 
4127 // Return the rate cost for chroma prediction mode info. of intra blocks.
intra_mode_info_cost_uv(const AV1_COMP * cpi,const MACROBLOCK * x,const MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int mode_cost)4128 static int intra_mode_info_cost_uv(const AV1_COMP *cpi, const MACROBLOCK *x,
4129                                    const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
4130                                    int mode_cost) {
4131   int total_rate = mode_cost;
4132   const int use_palette = mbmi->palette_mode_info.palette_size[1] > 0;
4133   const UV_PREDICTION_MODE mode = mbmi->uv_mode;
4134   // Can only activate one mode.
4135   assert(((mode != UV_DC_PRED) + use_palette + mbmi->use_intrabc) <= 1);
4136 
4137   const int try_palette =
4138       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
4139   if (try_palette && mode == UV_DC_PRED) {
4140     const PALETTE_MODE_INFO *pmi = &mbmi->palette_mode_info;
4141     total_rate +=
4142         x->palette_uv_mode_cost[pmi->palette_size[0] > 0][use_palette];
4143     if (use_palette) {
4144       const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
4145       const int plt_size = pmi->palette_size[1];
4146       const MACROBLOCKD *xd = &x->e_mbd;
4147       const uint8_t *const color_map = xd->plane[1].color_index_map;
4148       int palette_mode_cost =
4149           x->palette_uv_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
4150           write_uniform_cost(plt_size, color_map[0]);
4151       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
4152       const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
4153       palette_mode_cost += av1_palette_color_cost_uv(
4154           pmi, color_cache, n_cache, cpi->common.seq_params.bit_depth);
4155       palette_mode_cost +=
4156           av1_cost_color_map(x, 1, bsize, mbmi->tx_size, PALETTE_MAP);
4157       total_rate += palette_mode_cost;
4158     }
4159   }
4160   if (av1_is_directional_mode(get_uv_mode(mode))) {
4161     if (av1_use_angle_delta(bsize)) {
4162       total_rate +=
4163           x->angle_delta_cost[mode - V_PRED][mbmi->angle_delta[PLANE_TYPE_UV] +
4164                                              MAX_ANGLE_DELTA];
4165     }
4166   }
4167   return total_rate;
4168 }
4169 
// Returns 1 when the diagonal intra |mode| should be skipped. This happens
// when neither of its two neighbouring directions won the search so far
// (|best_intra_mode|). Any other mode is never skipped (returns 0).
static int conditional_skipintra(PREDICTION_MODE mode,
                                 PREDICTION_MODE best_intra_mode) {
  switch (mode) {
    case D113_PRED:
      return best_intra_mode != V_PRED && best_intra_mode != D135_PRED;
    case D67_PRED:
      return best_intra_mode != V_PRED && best_intra_mode != D45_PRED;
    case D203_PRED:
      return best_intra_mode != H_PRED && best_intra_mode != D45_PRED;
    case D157_PRED:
      return best_intra_mode != H_PRED && best_intra_mode != D135_PRED;
    default:
      return 0;
  }
}
4186 
4187 // Model based RD estimation for luma intra blocks.
intra_model_yrd(const AV1_COMP * const cpi,MACROBLOCK * const x,BLOCK_SIZE bsize,int mode_cost,int mi_row,int mi_col)4188 static int64_t intra_model_yrd(const AV1_COMP *const cpi, MACROBLOCK *const x,
4189                                BLOCK_SIZE bsize, int mode_cost, int mi_row,
4190                                int mi_col) {
4191   const AV1_COMMON *cm = &cpi->common;
4192   MACROBLOCKD *const xd = &x->e_mbd;
4193   MB_MODE_INFO *const mbmi = xd->mi[0];
4194   assert(!is_inter_block(mbmi));
4195   RD_STATS this_rd_stats;
4196   int row, col;
4197   int64_t temp_sse, this_rd;
4198   TX_SIZE tx_size = tx_size_from_tx_mode(bsize, cm->tx_mode);
4199   const int stepr = tx_size_high_unit[tx_size];
4200   const int stepc = tx_size_wide_unit[tx_size];
4201   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
4202   const int max_blocks_high = max_block_high(xd, bsize, 0);
4203   mbmi->tx_size = tx_size;
4204   // Prediction.
4205   for (row = 0; row < max_blocks_high; row += stepr) {
4206     for (col = 0; col < max_blocks_wide; col += stepc) {
4207       av1_predict_intra_block_facade(cm, xd, 0, col, row, tx_size);
4208     }
4209   }
4210   // RD estimation.
4211   model_rd_sb_fn[MODELRD_TYPE_INTRA](
4212       cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &this_rd_stats.rate,
4213       &this_rd_stats.dist, &this_rd_stats.skip, &temp_sse, NULL, NULL, NULL);
4214   if (av1_is_directional_mode(mbmi->mode) && av1_use_angle_delta(bsize)) {
4215     mode_cost +=
4216         x->angle_delta_cost[mbmi->mode - V_PRED]
4217                            [MAX_ANGLE_DELTA + mbmi->angle_delta[PLANE_TYPE_Y]];
4218   }
4219   if (mbmi->mode == DC_PRED &&
4220       av1_filter_intra_allowed_bsize(cm, mbmi->sb_type)) {
4221     if (mbmi->filter_intra_mode_info.use_filter_intra) {
4222       const int mode = mbmi->filter_intra_mode_info.filter_intra_mode;
4223       mode_cost += x->filter_intra_cost[mbmi->sb_type][1] +
4224                    x->filter_intra_mode_cost[mode];
4225     } else {
4226       mode_cost += x->filter_intra_cost[mbmi->sb_type][0];
4227     }
4228   }
4229   this_rd =
4230       RDCOST(x->rdmult, this_rd_stats.rate + mode_cost, this_rd_stats.dist);
4231   return this_rd;
4232 }
4233 
// Re-lays 'color_map' in place from an 'orig_width x orig_height' layout
// (row stride orig_width) to 'new_width x new_height' (row stride new_width).
// Extra columns repeat each row's last valid entry, and extra rows repeat
// the last valid row. The buffer must hold new_width * new_height bytes.
static void extend_palette_color_map(uint8_t *const color_map, int orig_width,
                                     int orig_height, int new_width,
                                     int new_height) {
  assert(new_width >= orig_width);
  assert(new_height >= orig_height);
  if (new_width == orig_width && new_height == orig_height) return;

  const int pad_cols = new_width - orig_width;
  // Go bottom-up: the destination of each row lies at or after its source,
  // so no row is clobbered before it has been moved.
  for (int row = orig_height - 1; row >= 0; --row) {
    uint8_t *const dst_row = color_map + row * new_width;
    memmove(dst_row, color_map + row * orig_width, orig_width);
    memset(dst_row + orig_width, dst_row[orig_width - 1], pad_cols);
  }
  for (int row = orig_height; row < new_height; ++row) {
    memcpy(color_map + row * new_width,
           color_map + (orig_height - 1) * new_width, new_width);
  }
}
4257 
// Biases palette centroids toward colors already present in the palette
// cache. Every stride-th entry of centroids[] (n_colors entries in total) is
// compared against all n_cache cache colors; if the nearest cache color is
// within +/-1, the centroid is replaced by it. Ties go to the earliest cache
// entry. Does nothing when the cache is empty.
// TODO(huisu): Try other schemes to improve compression.
static void optimize_palette_colors(uint16_t *color_cache, int n_cache,
                                    int n_colors, int stride, int *centroids) {
  if (n_cache <= 0) return;
  const int end = n_colors * stride;
  for (int pos = 0; pos < end; pos += stride) {
    const int value = centroids[pos];
    int nearest = 0;
    int nearest_diff = abs(value - (int)color_cache[0]);
    for (int k = 1; k < n_cache; ++k) {
      const int diff = abs(value - (int)color_cache[k]);
      if (diff < nearest_diff) {
        nearest_diff = diff;
        nearest = k;
      }
    }
    if (nearest_diff <= 1) centroids[pos] = color_cache[nearest];
  }
}
4276 
// Given the base colors as specified in centroids[], calculate the RD cost
// of palette mode.
//
// The candidate palette comes from the first n entries of centroids[], which
// are modified in place (snapped toward color_cache entries, then
// de-duplicated). 'data' holds one source luma sample per pixel in row-major
// rows x cols order, as filled in by rd_pick_palette_intra_sby().
//
// If the candidate's RD cost beats *best_rd it becomes the new best:
// *best_rd, *best_mbmi, best_palette_color_map, blk_skip and *rate_overhead
// are updated, and *rate / *rate_tokenonly / *distortion / *skippable are
// written when the corresponding pointer is non-NULL. *best_model_rd is
// lowered whenever this candidate's model RD is smaller. Once the candidate
// passes the PALETTE_MIN_SIZE check, mbmi's luma palette and the plane-0
// color index map describe this candidate whether or not it wins.
static void palette_rd_y(const AV1_COMP *const cpi, MACROBLOCK *x,
                         MB_MODE_INFO *mbmi, BLOCK_SIZE bsize, int mi_row,
                         int mi_col, int dc_mode_cost, const int *data,
                         int *centroids, int n, uint16_t *color_cache,
                         int n_cache, MB_MODE_INFO *best_mbmi,
                         uint8_t *best_palette_color_map, int64_t *best_rd,
                         int64_t *best_model_rd, int *rate, int *rate_tokenonly,
                         int *rate_overhead, int64_t *distortion,
                         int *skippable, PICK_MODE_CONTEXT *ctx,
                         uint8_t *blk_skip) {
  // Pull centroids that are within 1 of a cached color onto that color, then
  // drop duplicates; k is the number of distinct colors that remain.
  optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
  int k = av1_remove_duplicates(centroids, n);
  if (k < PALETTE_MIN_SIZE) {
    // Too few unique colors to create a palette. And DC_PRED will work
    // well for that case anyway. So skip.
    return;
  }
  // Store the centroids as the luma palette, clipped with clip_pixel() or
  // clip_pixel_highbd() according to the sequence bit depth.
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  if (cpi->common.seq_params.use_highbitdepth)
    for (int i = 0; i < k; ++i)
      pmi->palette_colors[i] = clip_pixel_highbd(
          (int)centroids[i], cpi->common.seq_params.bit_depth);
  else
    for (int i = 0; i < k; ++i)
      pmi->palette_colors[i] = clip_pixel(centroids[i]);
  pmi->palette_size[0] = k;
  MACROBLOCKD *const xd = &x->e_mbd;
  uint8_t *const color_map = xd->plane[0].color_index_map;
  // rows x cols is the region covered by data[]; block_width x block_height
  // is the full index-map size. NOTE(review): presumably these differ only
  // for blocks crossing the frame edge -- confirm in av1_get_block_dimensions.
  int block_width, block_height, rows, cols;
  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
                           &cols);
  // Assign every pixel a palette index, then pad the index map out to the
  // full block size by replicating its last column and last row.
  av1_calc_indices(data, centroids, color_map, rows * cols, k, 1);
  extend_palette_color_map(color_map, cols, rows, block_width, block_height);
  const int palette_mode_cost =
      intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
  int64_t this_model_rd =
      intra_model_yrd(cpi, x, bsize, palette_mode_cost, mi_row, mi_col);
  // Cheap pruning: give up if the model RD exceeds 1.5x the best model RD.
  if (*best_model_rd != INT64_MAX &&
      this_model_rd > *best_model_rd + (*best_model_rd >> 1))
    return;
  if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
  RD_STATS tokenonly_rd_stats;
  super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
  if (tokenonly_rd_stats.rate == INT_MAX) return;
  int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
  int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
  // Move the tx_size signaling cost out of the token-only rate. This runs
  // after this_rate/this_rd were computed, so the total still includes it and
  // it ends up in *rate_overhead (this_rate - token-only rate) below.
  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
    tokenonly_rd_stats.rate -=
        tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
  }
  if (this_rd < *best_rd) {
    *best_rd = this_rd;
    memcpy(best_palette_color_map, color_map,
           block_width * block_height * sizeof(color_map[0]));
    *best_mbmi = *mbmi;
    memcpy(blk_skip, x->blk_skip, sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
    *rate_overhead = this_rate - tokenonly_rd_stats.rate;
    if (rate) *rate = this_rate;
    if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
    if (distortion) *distortion = tokenonly_rd_stats.dist;
    if (skippable) *skippable = tokenonly_rd_stats.skip;
  }
}
4342 
rd_pick_palette_intra_sby(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int mi_row,int mi_col,int dc_mode_cost,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,PICK_MODE_CONTEXT * ctx,uint8_t * best_blk_skip)4343 static int rd_pick_palette_intra_sby(
4344     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
4345     int mi_col, int dc_mode_cost, MB_MODE_INFO *best_mbmi,
4346     uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
4347     int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
4348     PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip) {
4349   int rate_overhead = 0;
4350   MACROBLOCKD *const xd = &x->e_mbd;
4351   MB_MODE_INFO *const mbmi = xd->mi[0];
4352   assert(!is_inter_block(mbmi));
4353   assert(av1_allow_palette(cpi->common.allow_screen_content_tools, bsize));
4354   const SequenceHeader *const seq_params = &cpi->common.seq_params;
4355   int colors, n;
4356   const int src_stride = x->plane[0].src.stride;
4357   const uint8_t *const src = x->plane[0].src.buf;
4358   uint8_t *const color_map = xd->plane[0].color_index_map;
4359   int block_width, block_height, rows, cols;
4360   av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
4361                            &cols);
4362 
4363   int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
4364   if (seq_params->use_highbitdepth)
4365     colors = av1_count_colors_highbd(src, src_stride, rows, cols,
4366                                      seq_params->bit_depth, count_buf);
4367   else
4368     colors = av1_count_colors(src, src_stride, rows, cols, count_buf);
4369   mbmi->filter_intra_mode_info.use_filter_intra = 0;
4370 
4371   if (colors > 1 && colors <= 64) {
4372     int r, c, i;
4373     const int max_itr = 50;
4374     int *const data = x->palette_buffer->kmeans_data_buf;
4375     int centroids[PALETTE_MAX_SIZE];
4376     int lb, ub, val;
4377     uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
4378     if (seq_params->use_highbitdepth)
4379       lb = ub = src16[0];
4380     else
4381       lb = ub = src[0];
4382 
4383     if (seq_params->use_highbitdepth) {
4384       for (r = 0; r < rows; ++r) {
4385         for (c = 0; c < cols; ++c) {
4386           val = src16[r * src_stride + c];
4387           data[r * cols + c] = val;
4388           if (val < lb)
4389             lb = val;
4390           else if (val > ub)
4391             ub = val;
4392         }
4393       }
4394     } else {
4395       for (r = 0; r < rows; ++r) {
4396         for (c = 0; c < cols; ++c) {
4397           val = src[r * src_stride + c];
4398           data[r * cols + c] = val;
4399           if (val < lb)
4400             lb = val;
4401           else if (val > ub)
4402             ub = val;
4403         }
4404       }
4405     }
4406 
4407     mbmi->mode = DC_PRED;
4408     mbmi->filter_intra_mode_info.use_filter_intra = 0;
4409 
4410     uint16_t color_cache[2 * PALETTE_MAX_SIZE];
4411     const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
4412 
4413     // Find the dominant colors, stored in top_colors[].
4414     int top_colors[PALETTE_MAX_SIZE] = { 0 };
4415     for (i = 0; i < AOMMIN(colors, PALETTE_MAX_SIZE); ++i) {
4416       int max_count = 0;
4417       for (int j = 0; j < (1 << seq_params->bit_depth); ++j) {
4418         if (count_buf[j] > max_count) {
4419           max_count = count_buf[j];
4420           top_colors[i] = j;
4421         }
4422       }
4423       assert(max_count > 0);
4424       count_buf[top_colors[i]] = 0;
4425     }
4426 
4427     // Try the dominant colors directly.
4428     // TODO(huisu@google.com): Try to avoid duplicate computation in cases
4429     // where the dominant colors and the k-means results are similar.
4430     for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
4431       for (i = 0; i < n; ++i) centroids[i] = top_colors[i];
4432       palette_rd_y(cpi, x, mbmi, bsize, mi_row, mi_col, dc_mode_cost, data,
4433                    centroids, n, color_cache, n_cache, best_mbmi,
4434                    best_palette_color_map, best_rd, best_model_rd, rate,
4435                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
4436                    best_blk_skip);
4437     }
4438 
4439     // K-means clustering.
4440     for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
4441       if (colors == PALETTE_MIN_SIZE) {
4442         // Special case: These colors automatically become the centroids.
4443         assert(colors == n);
4444         assert(colors == 2);
4445         centroids[0] = lb;
4446         centroids[1] = ub;
4447       } else {
4448         for (i = 0; i < n; ++i) {
4449           centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
4450         }
4451         av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
4452       }
4453       palette_rd_y(cpi, x, mbmi, bsize, mi_row, mi_col, dc_mode_cost, data,
4454                    centroids, n, color_cache, n_cache, best_mbmi,
4455                    best_palette_color_map, best_rd, best_model_rd, rate,
4456                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
4457                    best_blk_skip);
4458     }
4459   }
4460 
4461   if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
4462     memcpy(color_map, best_palette_color_map,
4463            block_width * block_height * sizeof(best_palette_color_map[0]));
4464   }
4465   *mbmi = *best_mbmi;
4466   return rate_overhead;
4467 }
4468 
4469 // Return 1 if an filter intra mode is selected; return 0 otherwise.
rd_pick_filter_intra_sby(const AV1_COMP * const cpi,MACROBLOCK * x,int mi_row,int mi_col,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,BLOCK_SIZE bsize,int mode_cost,int64_t * best_rd,int64_t * best_model_rd,PICK_MODE_CONTEXT * ctx)4470 static int rd_pick_filter_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
4471                                     int mi_row, int mi_col, int *rate,
4472                                     int *rate_tokenonly, int64_t *distortion,
4473                                     int *skippable, BLOCK_SIZE bsize,
4474                                     int mode_cost, int64_t *best_rd,
4475                                     int64_t *best_model_rd,
4476                                     PICK_MODE_CONTEXT *ctx) {
4477   MACROBLOCKD *const xd = &x->e_mbd;
4478   MB_MODE_INFO *mbmi = xd->mi[0];
4479   int filter_intra_selected_flag = 0;
4480   FILTER_INTRA_MODE mode;
4481   TX_SIZE best_tx_size = TX_8X8;
4482   FILTER_INTRA_MODE_INFO filter_intra_mode_info;
4483   TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
4484   (void)ctx;
4485   av1_zero(filter_intra_mode_info);
4486   mbmi->filter_intra_mode_info.use_filter_intra = 1;
4487   mbmi->mode = DC_PRED;
4488   mbmi->palette_mode_info.palette_size[0] = 0;
4489 
4490   for (mode = 0; mode < FILTER_INTRA_MODES; ++mode) {
4491     int64_t this_rd, this_model_rd;
4492     RD_STATS tokenonly_rd_stats;
4493     mbmi->filter_intra_mode_info.filter_intra_mode = mode;
4494     this_model_rd = intra_model_yrd(cpi, x, bsize, mode_cost, mi_row, mi_col);
4495     if (*best_model_rd != INT64_MAX &&
4496         this_model_rd > *best_model_rd + (*best_model_rd >> 1))
4497       continue;
4498     if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
4499     super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
4500     if (tokenonly_rd_stats.rate == INT_MAX) continue;
4501     const int this_rate =
4502         tokenonly_rd_stats.rate +
4503         intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
4504     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
4505 
4506     if (this_rd < *best_rd) {
4507       *best_rd = this_rd;
4508       best_tx_size = mbmi->tx_size;
4509       filter_intra_mode_info = mbmi->filter_intra_mode_info;
4510       memcpy(best_txk_type, mbmi->txk_type,
4511              sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
4512       memcpy(ctx->blk_skip, x->blk_skip,
4513              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
4514       *rate = this_rate;
4515       *rate_tokenonly = tokenonly_rd_stats.rate;
4516       *distortion = tokenonly_rd_stats.dist;
4517       *skippable = tokenonly_rd_stats.skip;
4518       filter_intra_selected_flag = 1;
4519     }
4520   }
4521 
4522   if (filter_intra_selected_flag) {
4523     mbmi->mode = DC_PRED;
4524     mbmi->tx_size = best_tx_size;
4525     mbmi->filter_intra_mode_info = filter_intra_mode_info;
4526     memcpy(mbmi->txk_type, best_txk_type,
4527            sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
4528     return 1;
4529   } else {
4530     return 0;
4531   }
4532 }
4533 
4534 // Run RD calculation with given luma intra prediction angle., and return
4535 // the RD cost. Update the best mode info. if the RD cost is the best so far.
calc_rd_given_intra_angle(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int mi_row,int mi_col,int mode_cost,int64_t best_rd_in,int8_t angle_delta,int max_angle_delta,int * rate,RD_STATS * rd_stats,int * best_angle_delta,TX_SIZE * best_tx_size,int64_t * best_rd,int64_t * best_model_rd,TX_TYPE * best_txk_type,uint8_t * best_blk_skip)4536 static int64_t calc_rd_given_intra_angle(
4537     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
4538     int mi_col, int mode_cost, int64_t best_rd_in, int8_t angle_delta,
4539     int max_angle_delta, int *rate, RD_STATS *rd_stats, int *best_angle_delta,
4540     TX_SIZE *best_tx_size, int64_t *best_rd, int64_t *best_model_rd,
4541     TX_TYPE *best_txk_type, uint8_t *best_blk_skip) {
4542   RD_STATS tokenonly_rd_stats;
4543   int64_t this_rd, this_model_rd;
4544   MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
4545   const int n4 = bsize_to_num_blk(bsize);
4546   assert(!is_inter_block(mbmi));
4547   mbmi->angle_delta[PLANE_TYPE_Y] = angle_delta;
4548   this_model_rd = intra_model_yrd(cpi, x, bsize, mode_cost, mi_row, mi_col);
4549   if (*best_model_rd != INT64_MAX &&
4550       this_model_rd > *best_model_rd + (*best_model_rd >> 1))
4551     return INT64_MAX;
4552   if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
4553   super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd_in);
4554   if (tokenonly_rd_stats.rate == INT_MAX) return INT64_MAX;
4555 
4556   int this_rate =
4557       mode_cost + tokenonly_rd_stats.rate +
4558       x->angle_delta_cost[mbmi->mode - V_PRED][max_angle_delta + angle_delta];
4559   this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
4560 
4561   if (this_rd < *best_rd) {
4562     memcpy(best_txk_type, mbmi->txk_type,
4563            sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
4564     memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * n4);
4565     *best_rd = this_rd;
4566     *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_Y];
4567     *best_tx_size = mbmi->tx_size;
4568     *rate = this_rate;
4569     rd_stats->rate = tokenonly_rd_stats.rate;
4570     rd_stats->dist = tokenonly_rd_stats.dist;
4571     rd_stats->skip = tokenonly_rd_stats.skip;
4572   }
4573   return this_rd;
4574 }
4575 
// With given luma directional intra prediction mode, pick the best angle
// delta. Return the RD cost corresponding to the best angle delta.
//
// The search runs in two passes. Pass 1 evaluates the even deltas (0, then
// +/-2, ... up to MAX_ANGLE_DELTA) with a slightly relaxed RD budget. Pass 2
// evaluates an odd delta only if at least one of its even neighbors of the
// same sign came within best_rd / 32 of the running best RD cost. If delta 0
// cannot be evaluated at all (model-pruned or aborted), the search stops
// immediately and returns best_rd unchanged.
//
// *rate and *rd_stats are written by calc_rd_given_intra_angle() only when a
// delta beats the running best RD. The caller in this file presets
// rd_stats->rate to INT_MAX, so the final check below detects whether any
// delta won; if one did, mbmi's tx_size, luma angle delta and txk_type, and
// x->blk_skip, are restored to the winner's values.
static int64_t rd_pick_intra_angle_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       int mi_row, int mi_col, int *rate,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int mode_cost, int64_t best_rd,
                                       int64_t *best_model_rd) {
  MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
  assert(!is_inter_block(mbmi));

  int best_angle_delta = 0;
  // rd_cost[2 * |delta| + (delta < 0)] is the RD cost of each evaluated delta
  // (INT64_MAX if unevaluated or rejected). The two extra slots keep pass 2's
  // lookup of |delta| + 1 beyond MAX_ANGLE_DELTA in bounds.
  int64_t rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
  TX_SIZE best_tx_size = mbmi->tx_size;
  TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];

  for (int i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;

  // Pass 1: even deltas. The first evaluation may exceed best_rd by 1/8, the
  // later ones by 1/32, before super_block_yrd() is allowed to bail out.
  int first_try = 1;
  for (int angle_delta = 0; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
    for (int i = 0; i < 2; ++i) {
      const int64_t best_rd_in =
          (best_rd == INT64_MAX) ? INT64_MAX
                                 : (best_rd + (best_rd >> (first_try ? 3 : 5)));
      const int64_t this_rd = calc_rd_given_intra_angle(
          cpi, x, bsize, mi_row, mi_col, mode_cost, best_rd_in,
          (1 - 2 * i) * angle_delta, MAX_ANGLE_DELTA, rate, rd_stats,
          &best_angle_delta, &best_tx_size, &best_rd, best_model_rd,
          best_txk_type, best_blk_skip);
      rd_cost[2 * angle_delta + i] = this_rd;
      if (first_try && this_rd == INT64_MAX) return best_rd;
      first_try = 0;
      // +0 and -0 are the same delta: fill both slots and skip i == 1.
      if (angle_delta == 0) {
        rd_cost[1] = this_rd;
        break;
      }
    }
  }

  // Pass 2: odd deltas, only next to a promising even neighbor.
  assert(best_rd != INT64_MAX);
  for (int angle_delta = 1; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
    for (int i = 0; i < 2; ++i) {
      int skip_search = 0;
      const int64_t rd_thresh = best_rd + (best_rd >> 5);
      if (rd_cost[2 * (angle_delta + 1) + i] > rd_thresh &&
          rd_cost[2 * (angle_delta - 1) + i] > rd_thresh)
        skip_search = 1;
      if (!skip_search) {
        calc_rd_given_intra_angle(cpi, x, bsize, mi_row, mi_col, mode_cost,
                                  best_rd, (1 - 2 * i) * angle_delta,
                                  MAX_ANGLE_DELTA, rate, rd_stats,
                                  &best_angle_delta, &best_tx_size, &best_rd,
                                  best_model_rd, best_txk_type, best_blk_skip);
      }
    }
  }

  // Restore the winner's state (later evaluations overwrote mbmi/x).
  if (rd_stats->rate != INT_MAX) {
    mbmi->tx_size = best_tx_size;
    mbmi->angle_delta[PLANE_TYPE_Y] = best_angle_delta;
    memcpy(mbmi->txk_type, best_txk_type,
           sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
    memcpy(x->blk_skip, best_blk_skip,
           sizeof(best_blk_skip[0]) * bsize_to_num_blk(bsize));
  }
  return best_rd;
}
4643 
// Indices are sign, integer, and fractional part of the gradient value.
// Looked up when building the gradient histograms for angle_estimation().
// For a pixel with horizontal difference dx and non-zero vertical difference
// dy, it is indexed as [sn][quot][remd] where:
//   sn   = (dx > 0) ^ (dy > 0)                    (sign flag)
//   quot = min(|dx| / |dy|, 6)                    (integer part of |dx/dy|)
//   remd = min(16 * (|dx| % |dy|) / |dy|, 15)     (fraction, in 1/16 units)
// Gradients with dy == 0 bypass this table and go to bin 2. The stored
// values are histogram bin indices in [0, 7].
static const uint8_t gradient_to_angle_bin[2][7][16] = {
  {
      { 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0 },
      { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 },
      { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
      { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
      { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
      { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
      { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
  },
  {
      { 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4 },
      { 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3 },
      { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
      { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
      { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
      { 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
      { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
  },
};
4665 
// Maps an intra prediction mode (the array index) to the gradient-histogram
// bin that angle_estimation() scores it against. Only entries for
// directional modes are read there. Any entries beyond the listed
// initializers are zero-initialized.
/* clang-format off */
static const uint8_t mode_to_angle_bin[INTRA_MODES] = {
  0, 2, 6, 0, 4, 3, 5, 7, 1, 0,
  0,
};
/* clang-format on */
4672 
get_gradient_hist(const uint8_t * src,int src_stride,int rows,int cols,uint64_t * hist)4673 static void get_gradient_hist(const uint8_t *src, int src_stride, int rows,
4674                               int cols, uint64_t *hist) {
4675   src += src_stride;
4676   for (int r = 1; r < rows; ++r) {
4677     for (int c = 1; c < cols; ++c) {
4678       int dx = src[c] - src[c - 1];
4679       int dy = src[c] - src[c - src_stride];
4680       int index;
4681       const int temp = dx * dx + dy * dy;
4682       if (dy == 0) {
4683         index = 2;
4684       } else {
4685         const int sn = (dx > 0) ^ (dy > 0);
4686         dx = abs(dx);
4687         dy = abs(dy);
4688         const int remd = (dx % dy) * 16 / dy;
4689         const int quot = dx / dy;
4690         index = gradient_to_angle_bin[sn][AOMMIN(quot, 6)][AOMMIN(remd, 15)];
4691       }
4692       hist[index] += temp;
4693     }
4694     src += src_stride;
4695   }
4696 }
4697 
get_highbd_gradient_hist(const uint8_t * src8,int src_stride,int rows,int cols,uint64_t * hist)4698 static void get_highbd_gradient_hist(const uint8_t *src8, int src_stride,
4699                                      int rows, int cols, uint64_t *hist) {
4700   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
4701   src += src_stride;
4702   for (int r = 1; r < rows; ++r) {
4703     for (int c = 1; c < cols; ++c) {
4704       int dx = src[c] - src[c - 1];
4705       int dy = src[c] - src[c - src_stride];
4706       int index;
4707       const int temp = dx * dx + dy * dy;
4708       if (dy == 0) {
4709         index = 2;
4710       } else {
4711         const int sn = (dx > 0) ^ (dy > 0);
4712         dx = abs(dx);
4713         dy = abs(dy);
4714         const int remd = (dx % dy) * 16 / dy;
4715         const int quot = dx / dy;
4716         index = gradient_to_angle_bin[sn][AOMMIN(quot, 6)][AOMMIN(remd, 15)];
4717       }
4718       hist[index] += temp;
4719     }
4720     src += src_stride;
4721   }
4722 }
4723 
angle_estimation(const uint8_t * src,int src_stride,int rows,int cols,BLOCK_SIZE bsize,int is_hbd,uint8_t * directional_mode_skip_mask)4724 static void angle_estimation(const uint8_t *src, int src_stride, int rows,
4725                              int cols, BLOCK_SIZE bsize, int is_hbd,
4726                              uint8_t *directional_mode_skip_mask) {
4727   // Check if angle_delta is used
4728   if (!av1_use_angle_delta(bsize)) return;
4729 
4730   uint64_t hist[DIRECTIONAL_MODES] = { 0 };
4731   if (is_hbd)
4732     get_highbd_gradient_hist(src, src_stride, rows, cols, hist);
4733   else
4734     get_gradient_hist(src, src_stride, rows, cols, hist);
4735 
4736   int i;
4737   uint64_t hist_sum = 0;
4738   for (i = 0; i < DIRECTIONAL_MODES; ++i) hist_sum += hist[i];
4739   for (i = 0; i < INTRA_MODES; ++i) {
4740     if (av1_is_directional_mode(i)) {
4741       const uint8_t angle_bin = mode_to_angle_bin[i];
4742       uint64_t score = 2 * hist[angle_bin];
4743       int weight = 2;
4744       if (angle_bin > 0) {
4745         score += hist[angle_bin - 1];
4746         ++weight;
4747       }
4748       if (angle_bin < DIRECTIONAL_MODES - 1) {
4749         score += hist[angle_bin + 1];
4750         ++weight;
4751       }
4752       const int thresh = 10;
4753       if (score * thresh < hist_sum * weight) directional_mode_skip_mask[i] = 1;
4754     }
4755   }
4756 }
4757 
4758 // Given selected prediction mode, search for the best tx type and size.
intra_block_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,const int * bmode_costs,int64_t * best_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,MB_MODE_INFO * best_mbmi,PICK_MODE_CONTEXT * ctx)4759 static void intra_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
4760                             BLOCK_SIZE bsize, const int *bmode_costs,
4761                             int64_t *best_rd, int *rate, int *rate_tokenonly,
4762                             int64_t *distortion, int *skippable,
4763                             MB_MODE_INFO *best_mbmi, PICK_MODE_CONTEXT *ctx) {
4764   MACROBLOCKD *const xd = &x->e_mbd;
4765   MB_MODE_INFO *const mbmi = xd->mi[0];
4766   RD_STATS rd_stats;
4767   super_block_yrd(cpi, x, &rd_stats, bsize, *best_rd);
4768   if (rd_stats.rate == INT_MAX) return;
4769   int this_rate_tokenonly = rd_stats.rate;
4770   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
4771     // super_block_yrd above includes the cost of the tx_size in the
4772     // tokenonly rate, but for intra blocks, tx_size is always coded
4773     // (prediction granularity), so we account for it in the full rate,
4774     // not the tokenonly rate.
4775     this_rate_tokenonly -= tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
4776   }
4777   const int this_rate =
4778       rd_stats.rate +
4779       intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
4780   const int64_t this_rd = RDCOST(x->rdmult, this_rate, rd_stats.dist);
4781   if (this_rd < *best_rd) {
4782     *best_mbmi = *mbmi;
4783     *best_rd = this_rd;
4784     *rate = this_rate;
4785     *rate_tokenonly = this_rate_tokenonly;
4786     *distortion = rd_stats.dist;
4787     *skippable = rd_stats.skip;
4788     memcpy(ctx->blk_skip, x->blk_skip,
4789            sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
4790   }
4791 }
4792 
// This function is used only for intra_only frames
//
// Picks the best luma intra mode for the block. It walks the regular intra
// modes (with an angle-delta search for directional modes when enabled),
// then optionally tries palette and filter-intra modes. If the earlier
// searches used only the default transform type, it finally re-runs the
// transform search for the winner. On return xd->mi[0] holds the best mode
// info and the returned value is its RD cost. rate, rate_tokenonly,
// distortion, skippable and ctx->blk_skip describe the best mode whenever
// some candidate beat the incoming best_rd.
static int64_t rd_pick_intra_sby_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                      int mi_row, int mi_col, int *rate,
                                      int *rate_tokenonly, int64_t *distortion,
                                      int *skippable, BLOCK_SIZE bsize,
                                      int64_t best_rd, PICK_MODE_CONTEXT *ctx) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  int64_t best_model_rd = INT64_MAX;
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  int is_directional_mode;
  uint8_t directional_mode_skip_mask[INTRA_MODES] = { 0 };
  // Set once a regular intra mode beats the incoming best_rd; gates the
  // filter-intra search below.
  int beat_best_rd = 0;
  const int *bmode_costs;
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const int try_palette =
      cpi->oxcf.enable_palette &&
      av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
  uint8_t *best_palette_color_map =
      try_palette ? x->palette_buffer->best_palette_color_map : NULL;
  // The luma mode cost table is selected by the above and left neighbors'
  // modes.
  const MB_MODE_INFO *above_mi = xd->above_mbmi;
  const MB_MODE_INFO *left_mi = xd->left_mbmi;
  const PREDICTION_MODE A = av1_above_block_mode(above_mi);
  const PREDICTION_MODE L = av1_left_block_mode(left_mi);
  const int above_ctx = intra_mode_context[A];
  const int left_ctx = intra_mode_context[L];
  bmode_costs = x->y_mode_costs[above_ctx][left_ctx];

  // Optionally flag directional modes with weak gradient support so they are
  // skipped in the mode loop.
  mbmi->angle_delta[PLANE_TYPE_Y] = 0;
  if (cpi->sf.intra_angle_estimation) {
    const int src_stride = x->plane[0].src.stride;
    const uint8_t *src = x->plane[0].src.buf;
    angle_estimation(src, src_stride, rows, cols, bsize, is_cur_buf_hbd(xd),
                     directional_mode_skip_mask);
  }
  mbmi->filter_intra_mode_info.use_filter_intra = 0;
  pmi->palette_size[0] = 0;

  // With the fast intra tx type search, candidates are compared using only
  // the default tx type; the full tx type search is run once for the winner
  // at the end of this function.
  if (cpi->sf.tx_type_search.fast_intra_tx_type_search)
    x->use_default_intra_tx_type = 1;
  else
    x->use_default_intra_tx_type = 0;

  MB_MODE_INFO best_mbmi = *mbmi;
  /* Y Search for intra prediction mode */
  for (int mode_idx = INTRA_MODE_START; mode_idx < INTRA_MODE_END; ++mode_idx) {
    RD_STATS this_rd_stats;
    int this_rate, this_rate_tokenonly, s;
    int64_t this_distortion, this_rd, this_model_rd;
    mbmi->mode = intra_rd_search_mode_order[mode_idx];
    // Skip modes disabled by the encoder configuration or speed features.
    if ((!cpi->oxcf.enable_smooth_intra || cpi->sf.disable_smooth_intra) &&
        (mbmi->mode == SMOOTH_PRED || mbmi->mode == SMOOTH_H_PRED ||
         mbmi->mode == SMOOTH_V_PRED))
      continue;
    if (!cpi->oxcf.enable_paeth_intra && mbmi->mode == PAETH_PRED) continue;
    mbmi->angle_delta[PLANE_TYPE_Y] = 0;
    // Model-based pruning: skip modes whose model RD exceeds 1.5x the best.
    this_model_rd =
        intra_model_yrd(cpi, x, bsize, bmode_costs[mbmi->mode], mi_row, mi_col);
    if (best_model_rd != INT64_MAX &&
        this_model_rd > best_model_rd + (best_model_rd >> 1))
      continue;
    if (this_model_rd < best_model_rd) best_model_rd = this_model_rd;
    is_directional_mode = av1_is_directional_mode(mbmi->mode);
    if (is_directional_mode && directional_mode_skip_mask[mbmi->mode]) continue;
    if (is_directional_mode && av1_use_angle_delta(bsize) &&
        cpi->oxcf.enable_angle_delta) {
      // rd_pick_intra_angle_sby() fills this_rd_stats only when some angle
      // delta beats best_rd; otherwise rate stays INT_MAX and the mode is
      // skipped below.
      this_rd_stats.rate = INT_MAX;
      rd_pick_intra_angle_sby(cpi, x, mi_row, mi_col, &this_rate,
                              &this_rd_stats, bsize, bmode_costs[mbmi->mode],
                              best_rd, &best_model_rd);
    } else {
      super_block_yrd(cpi, x, &this_rd_stats, bsize, best_rd);
    }
    this_rate_tokenonly = this_rd_stats.rate;
    this_distortion = this_rd_stats.dist;
    s = this_rd_stats.skip;

    if (this_rate_tokenonly == INT_MAX) continue;

    if (!xd->lossless[mbmi->segment_id] &&
        block_signals_txsize(mbmi->sb_type)) {
      // super_block_yrd above includes the cost of the tx_size in the
      // tokenonly rate, but for intra blocks, tx_size is always coded
      // (prediction granularity), so we account for it in the full rate,
      // not the tokenonly rate.
      this_rate_tokenonly -=
          tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
    }
    this_rate =
        this_rd_stats.rate +
        intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
    this_rd = RDCOST(x->rdmult, this_rate, this_distortion);
    if (this_rd < best_rd) {
      best_mbmi = *mbmi;
      best_rd = this_rd;
      beat_best_rd = 1;
      *rate = this_rate;
      *rate_tokenonly = this_rate_tokenonly;
      *distortion = this_distortion;
      *skippable = s;
      memcpy(ctx->blk_skip, x->blk_skip,
             sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
    }
  }

  // Palette search; on return *mbmi has been overwritten with best_mbmi
  // (palette or not).
  if (try_palette) {
    rd_pick_palette_intra_sby(
        cpi, x, bsize, mi_row, mi_col, bmode_costs[DC_PRED], &best_mbmi,
        best_palette_color_map, &best_rd, &best_model_rd, rate, rate_tokenonly,
        distortion, skippable, ctx, ctx->blk_skip);
  }

  // Filter intra is only tried when some regular intra mode beat the
  // incoming best_rd; on success *mbmi holds the winning filter-intra mode.
  if (beat_best_rd && av1_filter_intra_allowed_bsize(&cpi->common, bsize)) {
    if (rd_pick_filter_intra_sby(
            cpi, x, mi_row, mi_col, rate, rate_tokenonly, distortion, skippable,
            bsize, bmode_costs[DC_PRED], &best_rd, &best_model_rd, ctx)) {
      best_mbmi = *mbmi;
    }
  }

  // If previous searches use only the default tx type, do an extra search for
  // the best tx type.
  if (x->use_default_intra_tx_type) {
    *mbmi = best_mbmi;
    x->use_default_intra_tx_type = 0;
    intra_block_yrd(cpi, x, bsize, bmode_costs, &best_rd, rate, rate_tokenonly,
                    distortion, skippable, &best_mbmi, ctx);
  }

  *mbmi = best_mbmi;
  return best_rd;
}
4927 
4928 // Return value 0: early termination triggered, no valid rd cost available;
4929 //              1: rd cost values are valid.
super_block_uvrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd)4930 static int super_block_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x,
4931                             RD_STATS *rd_stats, BLOCK_SIZE bsize,
4932                             int64_t ref_best_rd) {
4933   MACROBLOCKD *const xd = &x->e_mbd;
4934   MB_MODE_INFO *const mbmi = xd->mi[0];
4935   struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
4936   const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
4937   int plane;
4938   int is_cost_valid = 1;
4939   const int is_inter = is_inter_block(mbmi);
4940   int64_t this_rd = 0, skip_rd = 0;
4941   av1_init_rd_stats(rd_stats);
4942 
4943   if (ref_best_rd < 0) is_cost_valid = 0;
4944 
4945   if (x->skip_chroma_rd) return is_cost_valid;
4946 
4947   bsize = scale_chroma_bsize(bsize, pd->subsampling_x, pd->subsampling_y);
4948 
4949   if (is_inter && is_cost_valid) {
4950     for (plane = 1; plane < MAX_MB_PLANE; ++plane)
4951       av1_subtract_plane(x, bsize, plane);
4952   }
4953 
4954   if (is_cost_valid) {
4955     for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
4956       RD_STATS pn_rd_stats;
4957       int64_t chroma_ref_best_rd = ref_best_rd;
4958       // For inter blocks, refined ref_best_rd is used for early exit
4959       // For intra blocks, even though current rd crosses ref_best_rd, early
4960       // exit is not recommended as current rd is used for gating subsequent
4961       // modes as well (say, for angular modes)
4962       // TODO(any): Extend the early exit mechanism for intra modes as well
4963       if (cpi->sf.perform_best_rd_based_gating_for_chroma && is_inter &&
4964           chroma_ref_best_rd != INT64_MAX)
4965         chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_rd);
4966       txfm_rd_in_plane(x, cpi, &pn_rd_stats, chroma_ref_best_rd, 0, plane,
4967                        bsize, uv_tx_size, cpi->sf.use_fast_coef_costing,
4968                        FTXS_NONE, 0);
4969       if (pn_rd_stats.rate == INT_MAX) {
4970         is_cost_valid = 0;
4971         break;
4972       }
4973       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
4974       this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
4975       skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
4976       if (AOMMIN(this_rd, skip_rd) > ref_best_rd) {
4977         is_cost_valid = 0;
4978         break;
4979       }
4980     }
4981   }
4982 
4983   if (!is_cost_valid) {
4984     // reset cost value
4985     av1_invalid_rd_stats(rd_stats);
4986   }
4987 
4988   return is_cost_valid;
4989 }
4990 
4991 // Pick transform type for a transform block of tx_size.
tx_type_rd(const AV1_COMP * cpi,MACROBLOCK * x,TX_SIZE tx_size,int blk_row,int blk_col,int plane,int block,int plane_bsize,TXB_CTX * txb_ctx,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_rdcost,TXB_RD_INFO * rd_info_array)4992 static void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x, TX_SIZE tx_size,
4993                        int blk_row, int blk_col, int plane, int block,
4994                        int plane_bsize, TXB_CTX *txb_ctx, RD_STATS *rd_stats,
4995                        FAST_TX_SEARCH_MODE ftxs_mode, int64_t ref_rdcost,
4996                        TXB_RD_INFO *rd_info_array) {
4997   const struct macroblock_plane *const p = &x->plane[plane];
4998   const uint16_t cur_joint_ctx =
4999       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
5000   const int txk_type_idx =
5001       av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
5002   // Look up RD and terminate early in case when we've already processed exactly
5003   // the same residual with exactly the same entropy context.
5004   if (rd_info_array != NULL && rd_info_array->valid &&
5005       rd_info_array->entropy_context == cur_joint_ctx) {
5006     if (plane == 0)
5007       x->e_mbd.mi[0]->txk_type[txk_type_idx] = rd_info_array->tx_type;
5008     const TX_TYPE ref_tx_type =
5009         av1_get_tx_type(get_plane_type(plane), &x->e_mbd, blk_row, blk_col,
5010                         tx_size, cpi->common.reduced_tx_set_used);
5011     if (ref_tx_type == rd_info_array->tx_type) {
5012       rd_stats->rate += rd_info_array->rate;
5013       rd_stats->dist += rd_info_array->dist;
5014       rd_stats->sse += rd_info_array->sse;
5015       rd_stats->skip &= rd_info_array->eob == 0;
5016       p->eobs[block] = rd_info_array->eob;
5017       p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
5018       return;
5019     }
5020   }
5021 
5022   RD_STATS this_rd_stats;
5023   search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
5024                   txb_ctx, ftxs_mode, 0, 0, ref_rdcost, &this_rd_stats);
5025 
5026   av1_merge_rd_stats(rd_stats, &this_rd_stats);
5027 
5028   // Save RD results for possible reuse in future.
5029   if (rd_info_array != NULL) {
5030     rd_info_array->valid = 1;
5031     rd_info_array->entropy_context = cur_joint_ctx;
5032     rd_info_array->rate = this_rd_stats.rate;
5033     rd_info_array->dist = this_rd_stats.dist;
5034     rd_info_array->sse = this_rd_stats.sse;
5035     rd_info_array->eob = p->eobs[block];
5036     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
5037     if (plane == 0) {
5038       rd_info_array->tx_type = x->e_mbd.mi[0]->txk_type[txk_type_idx];
5039     }
5040   }
5041 }
5042 
get_mean_and_dev(const int16_t * data,int stride,int bw,int bh,float * mean,float * dev)5043 static void get_mean_and_dev(const int16_t *data, int stride, int bw, int bh,
5044                              float *mean, float *dev) {
5045   int x_sum = 0;
5046   uint64_t x2_sum = 0;
5047   for (int i = 0; i < bh; ++i) {
5048     for (int j = 0; j < bw; ++j) {
5049       const int val = data[j];
5050       x_sum += val;
5051       x2_sum += val * val;
5052     }
5053     data += stride;
5054   }
5055 
5056   const int num = bw * bh;
5057   const float e_x = (float)x_sum / num;
5058   const float e_x2 = (float)((double)x2_sum / num);
5059   const float diff = e_x2 - e_x * e_x;
5060   *dev = (diff > 0) ? sqrtf(diff) : 0;
5061   *mean = e_x;
5062 }
5063 
get_mean_and_dev_float(const float * data,int stride,int bw,int bh,float * mean,float * dev)5064 static void get_mean_and_dev_float(const float *data, int stride, int bw,
5065                                    int bh, float *mean, float *dev) {
5066   float x_sum = 0;
5067   float x2_sum = 0;
5068   for (int i = 0; i < bh; ++i) {
5069     for (int j = 0; j < bw; ++j) {
5070       const float val = data[j];
5071       x_sum += val;
5072       x2_sum += val * val;
5073     }
5074     data += stride;
5075   }
5076 
5077   const int num = bw * bh;
5078   const float e_x = x_sum / num;
5079   const float e_x2 = x2_sum / num;
5080   const float diff = e_x2 - e_x * e_x;
5081   *dev = (diff > 0) ? sqrtf(diff) : 0;
5082   *mean = e_x;
5083 }
5084 
// Feature used by the model to predict tx split: the mean and standard
// deviation values of the block and sub-blocks.
//
// Level 0 covers the whole bw x bh block. Each following level halves the
// sub-block size (both sides for squares, only the longer side otherwise).
// For every sub-block its mean and deviation are appended to `feature`; when a
// level has more than one sub-block, the deviation of the sub-block means and
// the mean of the sub-block deviations are appended too.
//
// Processing stops early when a sub-block side drops below 2, or when a level
// would contain more sub-blocks than the per-level scratch buffers hold (this
// guards against overflowing them for large `levels`). `feature` must be
// large enough for every value written; its size is the caller's
// responsibility.
static void get_mean_dev_features(const int16_t *data, int stride, int bw,
                                  int bh, int levels, float *feature) {
  enum { kMaxSubBlocks = 16 };
  int feature_idx = 0;
  int width = bw;
  int height = bh;
  for (int lv = 0; lv < levels; ++lv) {
    if (width < 2 || height < 2) break;
    const int sub_rows = (bh + height - 1) / height;
    const int sub_cols = (bw + width - 1) / width;
    if (sub_rows * sub_cols > kMaxSubBlocks) break;

    float mean_buf[kMaxSubBlocks];
    float dev_buf[kMaxSubBlocks];
    int blk_idx = 0;
    for (int row = 0; row < bh; row += height) {
      for (int col = 0; col < bw; col += width) {
        float mean, dev;
        get_mean_and_dev(data + row * stride + col, stride, width, height,
                         &mean, &dev);
        feature[feature_idx++] = mean;
        feature[feature_idx++] = dev;
        mean_buf[blk_idx] = mean;
        dev_buf[blk_idx++] = dev;
      }
    }
    if (blk_idx > 1) {
      float mean, dev;
      // Deviation of means.
      get_mean_and_dev_float(mean_buf, 1, 1, blk_idx, &mean, &dev);
      feature[feature_idx++] = dev;
      // Mean of deviations.
      get_mean_and_dev_float(dev_buf, 1, 1, blk_idx, &mean, &dev);
      feature[feature_idx++] = mean;
    }
    // Reduce the block size when proceeding to the next level.
    if (height == width) {
      height = height >> 1;
      width = width >> 1;
    } else if (height > width) {
      height = height >> 1;
    } else {
      width = width >> 1;
    }
  }
}
5129 
// Scores, with the tx-split neural network, how promising it is to split the
// luma transform block of size tx_size located at (blk_row, blk_col).
// Returns -1 when no model exists for tx_size. Otherwise returns an integer in
// [0, 100]: 100 / 0 when the network output is above 8 / below -8, and the
// sigmoid of the output scaled by 100 (truncated) in between.
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *const model = av1_tx_split_nnconfig_map[tx_size];
  if (model == NULL) return -1;

  // Residual rows are block_size_wide[bsize] samples apart; blk_row and
  // blk_col are given in units of 4 samples.
  const int stride = block_size_wide[bsize];
  const int16_t *const residue =
      x->plane[0].src_diff + 4 * blk_row * stride + 4 * blk_col;
  aom_clear_system_state();

  float features[64] = { 0.0f };
  get_mean_dev_features(residue, stride, tx_size_wide[tx_size],
                        tx_size_high[tx_size], 2, features);

  float logit = 0.0f;
  av1_nn_predict(features, model, &logit);
  aom_clear_system_state();

  if (logit > 8.0f) return 100;
  if (logit < -8.0f) return 0;
  const float prob = 1.0f / (1.0f + (float)exp(-logit));
  return (int)(prob * 100);
}
5153 
// Summary of one transform-partition candidate as produced by
// try_tx_block_no_split(). Field order matters: the struct is
// aggregate-initialized positionally as { INT64_MAX, 0, TX_TYPES }.
typedef struct {
  // RD cost of the candidate; INT64_MAX when it was not evaluated.
  int64_t rd;
  // Value of p->txb_entropy_ctx[block] recorded for the candidate.
  int txb_entropy_ctx;
  // Transform type recorded for the block; TX_TYPES when unset.
  TX_TYPE tx_type;
} TxCandidateInfo;
5159 
// Evaluates coding the luma transform block at (blk_row, blk_col) as a single
// tx_size block (no further split). Accumulates its stats into rd_stats and
// fills *no_split with the candidate's RD cost, entropy context and tx type.
//
// After the transform type search the block is forced to all-zero (skip)
// whenever that is not worse in RD terms, or when the search already reported
// it as skipped, unless the segment is lossless. The "no split" partition flag
// cost is added when the partition is signaled at this size/depth.
static void try_tx_block_no_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
    TxCandidateInfo *no_split) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  struct macroblock_plane *const p = &x->plane[0];
  const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];

  no_split->rd = INT64_MAX;
  no_split->txb_entropy_ctx = 0;
  no_split->tx_type = TX_TYPES;

  const ENTROPY_CONTEXT *const pta = ta + blk_col;
  const ENTROPY_CONTEXT *const ptl = tl + blk_row;

  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
  // Cost of signaling this transform block as all-zero.
  const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];

  rd_stats->ref_rdcost = ref_best_rd;
  rd_stats->zero_rate = zero_blk_rate;
  const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
  mbmi->inter_tx_size[index] = tx_size;
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, &txb_ctx,
             rd_stats, ftxs_mode, ref_best_rd,
             rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
  assert(rd_stats->rate < INT_MAX);

  if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
       rd_stats->skip == 1) &&
      !xd->lossless[mbmi->segment_id]) {
#if CONFIG_RD_DEBUG
    av1_update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
                              zero_blk_rate - rd_stats->rate);
#endif  // CONFIG_RD_DEBUG
    rd_stats->rate = zero_blk_rate;
    rd_stats->dist = rd_stats->sse;
    rd_stats->skip = 1;
    set_blk_skip(x, 0, blk_row * bw + blk_col, 1);
    p->eobs[block] = 0;
    // An all-zero block has a zero entropy context. Clear whatever the type
    // search left here so that no_split->txb_entropy_ctx (later used to set
    // the above/left contexts) matches what will actually be coded, as
    // tx_block_yrd() does for the same forced-zero case.
    p->txb_entropy_ctx[block] = 0;
    update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                     DCT_DCT);
  } else {
    set_blk_skip(x, 0, blk_row * bw + blk_col, 0);
    rd_stats->skip = 0;
  }

  if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
    rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0];

  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
  const int txk_type_idx =
      av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
  no_split->tx_type = mbmi->txk_type[txk_type_idx];
}
5223 
// Forward declaration: select_tx_block() and try_tx_block_split() call each
// other recursively.
static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                            int blk_col, int block, TX_SIZE tx_size, int depth,
                            BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
                            ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above,
                            TXFM_CONTEXT *tx_left, RD_STATS *rd_stats,
                            int64_t prev_level_rd, int64_t ref_best_rd,
                            int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
                            TXB_RD_INFO_NODE *rd_info_node);
5232 
// Evaluates splitting the luma transform block at (blk_row, blk_col) into
// sub_tx_size_map[tx_size] sub-blocks, each searched recursively with
// select_tx_block().
//
// split_rd_stats is expected to be initialized by the caller; its rate is
// overwritten with the "split" partition flag cost and the stats of every
// searched sub-block are merged into it. *split_rd receives the total RD cost
// of the split, or stays INT64_MAX when a sub-block search is invalid or the
// running cost exceeds no_split_rd. Sub-blocks beyond max_block_high/wide are
// not searched.
static void try_tx_block_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
    RD_STATS *split_rd_stats, int64_t *split_rd) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  const int rows_avail = max_block_high(xd, plane_bsize, 0);
  const int cols_avail = max_block_wide(xd, plane_bsize, 0);
  const TX_SIZE child_txs = sub_tx_size_map[tx_size];
  const int child_w = tx_size_wide_unit[child_txs];
  const int child_h = tx_size_high_unit[child_txs];
  const int child_step = child_w * child_h;
  const int parent_w = tx_size_wide_unit[tx_size];
  const int parent_h = tx_size_high_unit[tx_size];
  const int child_count = (parent_h / child_h) * (parent_w / child_w);
  assert(child_count > 0);

  *split_rd = INT64_MAX;
  split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];

  int64_t spent_rd = 0;
  int next_child = 0;
  for (int r = 0; r < parent_h; r += child_h) {
    for (int c = 0; c < parent_w; c += child_w) {
      // The child index advances even for sub-blocks that are skipped.
      const int child_idx = next_child++;
      assert(child_idx < 4);
      const int row = blk_row + r;
      const int col = blk_col + c;
      if (row >= rows_avail || col >= cols_avail) continue;

      TXB_RD_INFO_NODE *const child_node =
          (rd_info_node != NULL) ? rd_info_node->children[child_idx] : NULL;
      RD_STATS child_stats;
      int child_valid = 1;
      select_tx_block(cpi, x, row, col, block, child_txs, depth + 1,
                      plane_bsize, ta, tl, tx_above, tx_left, &child_stats,
                      no_split_rd / child_count, ref_best_rd - spent_rd,
                      &child_valid, ftxs_mode, child_node);
      if (!child_valid) return;

      av1_merge_rd_stats(split_rd_stats, &child_stats);
      spent_rd =
          RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
      // Already worse than not splitting: abandon the split.
      if (no_split_rd < spent_rd) return;
      block += child_step;
    }
  }

  *split_rd = spent_rd;
}
5280 
// Search for the best tx partition/type for a given luma block.
//
// Evaluates the transform block of size tx_size at (blk_row, blk_col) both
// unsplit (try_tx_block_no_split) and split (try_tx_block_split, which
// recurses back here) and keeps the cheaper candidate. On return rd_stats
// holds the winner's stats, and the winner is recorded in
// mbmi->inter_tx_size, mbmi->txk_type, the entropy contexts ta/tl and the txfm
// contexts tx_above/tx_left.
//
// prev_level_rd is the parent's no-split RD cost divided among its sub-blocks
// (INT64_MAX at the top level); it prunes the split search. rd_info_node, when
// non-NULL, provides cached transform type search results.
// *is_cost_valid is cleared when ref_best_rd is negative, when the adaptive
// search level predicts the block cannot fit in ref_best_rd, or when neither
// candidate produced a cost.
static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                            int blk_col, int block, TX_SIZE tx_size, int depth,
                            BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
                            ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above,
                            TXFM_CONTEXT *tx_left, RD_STATS *rd_stats,
                            int64_t prev_level_rd, int64_t ref_best_rd,
                            int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
                            TXB_RD_INFO_NODE *rd_info_node) {
  assert(tx_size < TX_SIZES_ALL);
  av1_init_rd_stats(rd_stats);
  if (ref_best_rd < 0) {
    *is_cost_valid = 0;
    return;
  }

  MACROBLOCKD *const xd = &x->e_mbd;
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
  // Blocks beyond max_block_high/wide are not searched; rd_stats stays
  // initialized and the cost remains valid.
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
  MB_MODE_INFO *const mbmi = xd->mi[0];
  // Sampled before the split search below updates tx_above/tx_left.
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->sb_type, tx_size);
  struct macroblock_plane *const p = &x->plane[0];

  // Unsplit sizes whose square-up size is TX_64X64 are only tried when
  // enable_tx64 is set.
  const int try_no_split =
      cpi->oxcf.enable_tx64 || txsize_sqr_up_map[tx_size] != TX_64X64;
  int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
#if CONFIG_DIST_8X8
  if (x->using_dist_8x8)
    try_split &= tx_size_wide[tx_size] >= 16 && tx_size_high[tx_size] >= 16;
#endif
  TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };

  // TX no split
  if (try_no_split) {
    try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
                          plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
                          ftxs_mode, rd_info_node, &no_split);

    // Give up on the block when even a discounted no-split cost exceeds the
    // budget.
    if (cpi->sf.adaptive_txb_search_level &&
        (no_split.rd -
         (no_split.rd >> (1 + cpi->sf.adaptive_txb_search_level))) >
            ref_best_rd) {
      *is_cost_valid = 0;
      return;
    }

    // p->eobs[block] still holds the no-split result at this point: an
    // all-zero block is not split further under txb_split_cap.
    if (cpi->sf.txb_split_cap) {
      if (p->eobs[block] == 0) try_split = 0;
    }

    // Skip the split search when the discounted no-split cost already exceeds
    // this block's share of the parent's no-split cost.
    if (cpi->sf.adaptive_txb_search_level &&
        (no_split.rd -
         (no_split.rd >> (2 + cpi->sf.adaptive_txb_search_level))) >
            prev_level_rd) {
      try_split = 0;
    }
  }

  // For 8-bit input with cb_partition_scan off, the ML model may veto the
  // split search when its score is below ml_tx_split_thresh.
  if (x->e_mbd.bd == 8 && !x->cb_partition_scan && try_split) {
    const int threshold = cpi->sf.tx_type_search.ml_tx_split_thresh;
    if (threshold >= 0) {
      const int split_score =
          ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
      if (split_score >= 0 && split_score < threshold) try_split = 0;
    }
  }

  // TX split
  int64_t split_rd = INT64_MAX;
  RD_STATS split_rd_stats;
  av1_init_rd_stats(&split_rd_stats);
  if (try_split) {
    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
                       rd_info_node, &split_rd_stats, &split_rd);
  }

  if (no_split.rd < split_rd) {
    // Keep the unsplit block: restore the state that the split search may
    // have overwritten (entropy/txfm contexts, tx sizes, tx types, skip flag).
    ENTROPY_CONTEXT *pta = ta + blk_col;
    ENTROPY_CONTEXT *ptl = tl + blk_row;
    const TX_SIZE tx_size_selected = tx_size;
    p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
    av1_set_txb_context(x, 0, block, tx_size_selected, pta, ptl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
        const int index =
            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
        mbmi->inter_tx_size[index] = tx_size_selected;
      }
    }
    mbmi->tx_size = tx_size_selected;
    update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                     no_split.tx_type);
    set_blk_skip(x, 0, blk_row * bw + blk_col, rd_stats->skip);
  } else {
    // Split wins (or neither candidate was evaluated, in which case both
    // costs are INT64_MAX and the cost is marked invalid).
    *rd_stats = split_rd_stats;
    if (split_rd == INT64_MAX) *is_cost_valid = 0;
  }
}
5387 
// Chooses the transform partition and transform types for the luma plane of
// an inter block. Returns its RD cost, including the cost of the skip flag, or
// INT64_MAX when no valid cost within ref_best_rd was found (rd_stats is then
// invalidated).
//
// The block is searched in units of max_txsize_rect_lookup[plane_bsize] with
// select_tx_block(). With a fast tx_size_search_method the budget is relaxed
// by 1/8 and only DCT and 1D DCT types are tried during that search. The
// transform types are then searched again by inter_block_yrd(), which reuses
// the partition recorded in mbmi->inter_tx_size. rd_info_tree, when non-NULL,
// supplies one cache node per search unit.
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd,
                                       TXB_RD_INFO_NODE *rd_info_tree) {
  MACROBLOCKD *const xd = &x->e_mbd;
  assert(is_inter_block(xd->mi[0]));

  // TODO(debargha): enable this as a speed feature where the
  // select_tx_block() search below will use a simplified search
  // such as not using full optimize, but the inter_block_yrd() function
  // will use more complex search given that the transform partitions have
  // already been decided.

  const int fast_tx_search = cpi->sf.tx_size_search_method > USE_FULL_RD;
  int64_t rd_thresh = ref_best_rd;
  // The fast search gets 1/8 extra budget, guarded against int64 overflow.
  if (fast_tx_search && rd_thresh < INT64_MAX) {
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
  }
  assert(rd_thresh > 0);

  const FAST_TX_SEARCH_MODE ftxs_mode =
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
  const struct macroblockd_plane *const pd = &xd->plane[0];
  const BLOCK_SIZE plane_bsize =
      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
  const int mi_width = mi_size_wide[plane_bsize];
  const int mi_height = mi_size_high[plane_bsize];
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);

  const int skip_ctx = av1_get_skip_context(xd);
  const int s0 = x->skip_cost[skip_ctx][0];
  const int s1 = x->skip_cost[skip_ctx][1];
  const int init_depth =
      get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
  const int bh = tx_size_high_unit[max_tx_size];
  const int bw = tx_size_wide_unit[max_tx_size];
  const int step = bw * bh;
  // Running RD of the block signaled as skip (skip_rd) or as coded (this_rd).
  int64_t skip_rd = RDCOST(x->rdmult, s1, 0);
  int64_t this_rd = RDCOST(x->rdmult, s0, 0);
  int block = 0;

  av1_init_rd_stats(rd_stats);
  for (int idy = 0; idy < mi_height; idy += bh) {
    for (int idx = 0; idx < mi_width; idx += bw) {
      const int64_t best_rd_sofar =
          (rd_thresh == INT64_MAX) ? INT64_MAX
                                   : (rd_thresh - (AOMMIN(skip_rd, this_rd)));
      int is_cost_valid = 1;
      RD_STATS pn_rd_stats;
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth,
                      plane_bsize, ctxa, ctxl, tx_above, tx_left, &pn_rd_stats,
                      INT64_MAX, best_rd_sofar, &is_cost_valid, ftxs_mode,
                      rd_info_tree);
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
        av1_invalid_rd_stats(rd_stats);
        return INT64_MAX;
      }
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
      skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
      this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
      block += step;
      if (rd_info_tree != NULL) rd_info_tree += 1;
    }
  }

  if (skip_rd <= this_rd) {
    rd_stats->skip = 1;
  } else {
    rd_stats->skip = 0;
  }

  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  // If fast_tx_search is true, only DCT and 1D DCT were tested in the
  // select_tx_block() search above. Do a better search for tx type with
  // tx sizes already decided.
  if (fast_tx_search) {
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE)) {
      // Match the other early exit: never hand back partially filled stats.
      av1_invalid_rd_stats(rd_stats);
      return INT64_MAX;
    }
  }

  int64_t rd;
  if (rd_stats->skip) {
    rd = RDCOST(x->rdmult, s1, rd_stats->sse);
  } else {
    rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
    if (!xd->lossless[xd->mi[0]->segment_id])
      rd = AOMMIN(rd, RDCOST(x->rdmult, s1, rd_stats->sse));
  }

  return rd;
}
5487 
5488 // Finds rd cost for a y block, given the transform size partitions
tx_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,BLOCK_SIZE plane_bsize,int depth,ENTROPY_CONTEXT * above_ctx,ENTROPY_CONTEXT * left_ctx,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,int64_t ref_best_rd,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode)5489 static void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
5490                          int blk_col, int block, TX_SIZE tx_size,
5491                          BLOCK_SIZE plane_bsize, int depth,
5492                          ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
5493                          TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
5494                          int64_t ref_best_rd, RD_STATS *rd_stats,
5495                          FAST_TX_SEARCH_MODE ftxs_mode) {
5496   MACROBLOCKD *const xd = &x->e_mbd;
5497   MB_MODE_INFO *const mbmi = xd->mi[0];
5498   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
5499   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
5500 
5501   assert(tx_size < TX_SIZES_ALL);
5502 
5503   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
5504 
5505   const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
5506       plane_bsize, blk_row, blk_col)];
5507 
5508   int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
5509                                    mbmi->sb_type, tx_size);
5510 
5511   av1_init_rd_stats(rd_stats);
5512   if (tx_size == plane_tx_size) {
5513     ENTROPY_CONTEXT *ta = above_ctx + blk_col;
5514     ENTROPY_CONTEXT *tl = left_ctx + blk_row;
5515     const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
5516     TXB_CTX txb_ctx;
5517     get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
5518 
5519     const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
5520                                   .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
5521     rd_stats->zero_rate = zero_blk_rate;
5522     rd_stats->ref_rdcost = ref_best_rd;
5523     tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
5524                &txb_ctx, rd_stats, ftxs_mode, ref_best_rd, NULL);
5525     const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
5526     if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
5527             RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
5528         rd_stats->skip == 1) {
5529       rd_stats->rate = zero_blk_rate;
5530       rd_stats->dist = rd_stats->sse;
5531       rd_stats->skip = 1;
5532       set_blk_skip(x, 0, blk_row * mi_width + blk_col, 1);
5533       x->plane[0].eobs[block] = 0;
5534       x->plane[0].txb_entropy_ctx[block] = 0;
5535       update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
5536                        DCT_DCT);
5537     } else {
5538       rd_stats->skip = 0;
5539       set_blk_skip(x, 0, blk_row * mi_width + blk_col, 0);
5540     }
5541     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
5542       rd_stats->rate += x->txfm_partition_cost[ctx][0];
5543     av1_set_txb_context(x, 0, block, tx_size, ta, tl);
5544     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
5545                           tx_size);
5546   } else {
5547     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
5548     const int bsw = tx_size_wide_unit[sub_txs];
5549     const int bsh = tx_size_high_unit[sub_txs];
5550     const int step = bsh * bsw;
5551     RD_STATS pn_rd_stats;
5552     int64_t this_rd = 0;
5553     assert(bsw > 0 && bsh > 0);
5554 
5555     for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
5556       for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
5557         const int offsetr = blk_row + row;
5558         const int offsetc = blk_col + col;
5559 
5560         if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
5561 
5562         av1_init_rd_stats(&pn_rd_stats);
5563         tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
5564                      depth + 1, above_ctx, left_ctx, tx_above, tx_left,
5565                      ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
5566         if (pn_rd_stats.rate == INT_MAX) {
5567           av1_invalid_rd_stats(rd_stats);
5568           return;
5569         }
5570         av1_merge_rd_stats(rd_stats, &pn_rd_stats);
5571         this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
5572         block += step;
5573       }
5574     }
5575 
5576     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
5577       rd_stats->rate += x->txfm_partition_cost[ctx][1];
5578   }
5579 }
5580 
5581 // Return value 0: early termination triggered, no valid rd cost available;
5582 //              1: rd cost values are valid.
inter_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode)5583 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
5584                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
5585                            int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
5586   MACROBLOCKD *const xd = &x->e_mbd;
5587   int is_cost_valid = 1;
5588   int64_t this_rd = 0;
5589 
5590   if (ref_best_rd < 0) is_cost_valid = 0;
5591 
5592   av1_init_rd_stats(rd_stats);
5593 
5594   if (is_cost_valid) {
5595     const struct macroblockd_plane *const pd = &xd->plane[0];
5596     const BLOCK_SIZE plane_bsize =
5597         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
5598     const int mi_width = mi_size_wide[plane_bsize];
5599     const int mi_height = mi_size_high[plane_bsize];
5600     const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, 0);
5601     const int bh = tx_size_high_unit[max_tx_size];
5602     const int bw = tx_size_wide_unit[max_tx_size];
5603     const int init_depth =
5604         get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
5605     int idx, idy;
5606     int block = 0;
5607     int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
5608     ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
5609     ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
5610     TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
5611     TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
5612     RD_STATS pn_rd_stats;
5613 
5614     av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
5615     memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
5616     memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
5617 
5618     for (idy = 0; idy < mi_height; idy += bh) {
5619       for (idx = 0; idx < mi_width; idx += bw) {
5620         av1_init_rd_stats(&pn_rd_stats);
5621         tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, plane_bsize,
5622                      init_depth, ctxa, ctxl, tx_above, tx_left,
5623                      ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
5624         if (pn_rd_stats.rate == INT_MAX) {
5625           av1_invalid_rd_stats(rd_stats);
5626           return 0;
5627         }
5628         av1_merge_rd_stats(rd_stats, &pn_rd_stats);
5629         this_rd +=
5630             AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
5631                    RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
5632         block += step;
5633       }
5634     }
5635   }
5636 
5637   const int skip_ctx = av1_get_skip_context(xd);
5638   const int s0 = x->skip_cost[skip_ctx][0];
5639   const int s1 = x->skip_cost[skip_ctx][1];
5640   int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
5641   this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
5642   if (skip_rd < this_rd) {
5643     this_rd = skip_rd;
5644     rd_stats->rate = 0;
5645     rd_stats->dist = rd_stats->sse;
5646     rd_stats->skip = 1;
5647   }
5648   if (this_rd > ref_best_rd) is_cost_valid = 0;
5649 
5650   if (!is_cost_valid) {
5651     // reset cost value
5652     av1_invalid_rd_stats(rd_stats);
5653   }
5654   return is_cost_valid;
5655 }
5656 
find_tx_size_rd_info(TXB_RD_RECORD * cur_record,const uint32_t hash)5657 static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
5658                                 const uint32_t hash) {
5659   // Linear search through the circular buffer to find matching hash.
5660   for (int i = cur_record->index_start - 1; i >= 0; i--) {
5661     if (cur_record->hash_vals[i] == hash) return i;
5662   }
5663   for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
5664     if (cur_record->hash_vals[i] == hash) return i;
5665   }
5666   int index;
5667   // If not found - add new RD info into the buffer and return its index
5668   if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
5669     index = (cur_record->index_start + cur_record->num) %
5670             TX_SIZE_RD_RECORD_BUFFER_LEN;
5671     cur_record->num++;
5672   } else {
5673     index = cur_record->index_start;
5674     cur_record->index_start =
5675         (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
5676   }
5677 
5678   cur_record->hash_vals[index] = hash;
5679   av1_zero(cur_record->tx_rd_info[index]);
5680   return index;
5681 }
5682 
// Static description of one node of the TX-size search tree whose nodes are
// laid out by find_tx_size_rd_records(): level by level (largest TX size
// first) and in raster order within each level. init_rd_record_tree() turns
// these indices into child pointers.
typedef struct {
  // Non-zero: the node has no children (init_rd_record_tree() zeroes them).
  int leaf;
  // Indices of child nodes in the node array; init_rd_record_tree() maps any
  // index <= 0 to a NULL child pointer.
  int8_t children[4];
} RD_RECORD_IDX_NODE;

// 8x8: a single leaf node.
static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
  { 1, { 0 } },
};

// 8x16: the root splits into two stacked halves (nodes 1, 2), both leaves.
static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0, 0, 0, 0 } },
  { 1, { 0, 0, 0, 0 } },
};

// 16x8: the root splits into two side-by-side halves (nodes 1, 2), both leaves.
static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0 } },
  { 1, { 0 } },
};

// 16x16: the root splits into four quadrant leaves (nodes 1-4).
static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
  { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
};

// 1:2 blocks (16x32, 32x64): the root splits into two stacked squares
// (nodes 1, 2), each of which splits into four quadrants (3-6 and 7-10).
// Only the listed entries are initialized by init_rd_record_tree().
static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 5, 6 } },
  { 0, { 7, 8, 9, 10 } },
};

// 2:1 blocks (32x16, 64x32): the root splits into two side-by-side squares
// (nodes 1, 2). Their quadrants interleave in the raster order of the next
// level (row 0: 3-6, row 1: 7-10).
static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 7, 8 } },
  { 0, { 5, 6, 9, 10 } },
};

// Square blocks (32x32, 64x64): root -> four quadrants (1-4) -> sixteen
// sub-quadrants (5-20), which are numbered in raster order across the block.
static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
  { 0, { 1, 2, 3, 4 } },     { 0, { 5, 6, 9, 10 } },    { 0, { 7, 8, 11, 12 } },
  { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
};

// 64x128: two top-level nodes (0, 1) stacked vertically, each splitting into
// four quadrants (2-9), which in turn split into sixteen (10-41).
static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
  { 0, { 2, 3, 4, 5 } },     { 0, { 6, 7, 8, 9 } },
  { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
  { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
  { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
  { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
};

// 128x64: two top-level nodes (0, 1) side by side; child indices follow the
// raster order of each deeper level across the whole block.
static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
  { 0, { 2, 3, 6, 7 } },     { 0, { 4, 5, 8, 9 } },
  { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
  { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
  { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
  { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
};

// 128x128: four top-level nodes (0-3) in a 2x2 grid; deeper levels follow the
// raster order of the whole block (4-19, then 20-83). The 84 nodes in total
// match the size of matched_rd_info in pick_tx_size_type_yrd().
static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
  { 0, { 4, 5, 8, 9 } },     { 0, { 6, 7, 10, 11 } },
  { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
  { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
  { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
  { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
  { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
  { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
  { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
  { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
  { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
};
5753 
5754 static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
5755   { 0, { 1, -1, 2, -1 } },
5756   { 0, { 3, 4, -1, -1 } },
5757   { 0, { 5, 6, -1, -1 } },
5758 };
5759 
// 4:1 blocks (32x8, 64x16): the root splits into two side-by-side halves
// (nodes 1, 2), each of which splits into two side-by-side squares
// (nodes 3-4 and 5-6).
static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};

// Index-tree layout per block size, consumed by init_rd_record_tree(). Block
// sizes without a layout map to NULL; their entry in rd_record_tree_size is 0,
// so init_rd_record_tree() never dereferences the NULL.
static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
  NULL,                    // BLOCK_4X4
  NULL,                    // BLOCK_4X8
  NULL,                    // BLOCK_8X4
  rd_record_tree_8x8,      // BLOCK_8X8
  rd_record_tree_8x16,     // BLOCK_8X16
  rd_record_tree_16x8,     // BLOCK_16X8
  rd_record_tree_16x16,    // BLOCK_16X16
  rd_record_tree_1_2,      // BLOCK_16X32
  rd_record_tree_2_1,      // BLOCK_32X16
  rd_record_tree_sqr,      // BLOCK_32X32
  rd_record_tree_1_2,      // BLOCK_32X64
  rd_record_tree_2_1,      // BLOCK_64X32
  rd_record_tree_sqr,      // BLOCK_64X64
  rd_record_tree_64x128,   // BLOCK_64X128
  rd_record_tree_128x64,   // BLOCK_128X64
  rd_record_tree_128x128,  // BLOCK_128X128
  NULL,                    // BLOCK_4X16
  NULL,                    // BLOCK_16X4
  rd_record_tree_1_4,      // BLOCK_8X32
  rd_record_tree_4_1,      // BLOCK_32X8
  rd_record_tree_1_4,      // BLOCK_16X64
  rd_record_tree_4_1,      // BLOCK_64X16
};
5790 
5791 static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
5792   0,                                                            // BLOCK_4X4
5793   0,                                                            // BLOCK_4X8
5794   0,                                                            // BLOCK_8X4
5795   sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X8
5796   sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_8X16
5797   sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_16X8
5798   sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE),    // BLOCK_16X16
5799   sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X32
5800   sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X16
5801   sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X32
5802   sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X64
5803   sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X32
5804   sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X64
5805   sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_64X128
5806   sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_128X64
5807   sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE),  // BLOCK_128X128
5808   0,                                                            // BLOCK_4X16
5809   0,                                                            // BLOCK_16X4
5810   sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X32
5811   sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X8
5812   sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X64
5813   sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X16
5814 };
5815 
init_rd_record_tree(TXB_RD_INFO_NODE * tree,BLOCK_SIZE bsize)5816 static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
5817                                        BLOCK_SIZE bsize) {
5818   const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
5819   const int size = rd_record_tree_size[bsize];
5820   for (int i = 0; i < size; ++i) {
5821     if (rd_record[i].leaf) {
5822       av1_zero(tree[i].children);
5823     } else {
5824       for (int j = 0; j < 4; ++j) {
5825         const int8_t idx = rd_record[i].children[j];
5826         tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
5827       }
5828     }
5829   }
5830 }
5831 
5832 // Go through all TX blocks that could be used in TX size search, compute
5833 // residual hash values for them and find matching RD info that stores previous
5834 // RD search results for these TX blocks. The idea is to prevent repeated
5835 // rate/distortion computations that happen because of the combination of
5836 // partition and TX size search. The resulting RD info records are returned in
5837 // the form of a quadtree for easier access in actual TX size search.
find_tx_size_rd_records(MACROBLOCK * x,BLOCK_SIZE bsize,int mi_row,int mi_col,TXB_RD_INFO_NODE * dst_rd_info)5838 static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
5839                                    int mi_col, TXB_RD_INFO_NODE *dst_rd_info) {
5840   TXB_RD_RECORD *rd_records_table[4] = { x->txb_rd_record_8X8,
5841                                          x->txb_rd_record_16X16,
5842                                          x->txb_rd_record_32X32,
5843                                          x->txb_rd_record_64X64 };
5844   const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
5845   const int bw = block_size_wide[bsize];
5846   const int bh = block_size_high[bsize];
5847 
5848   // Hashing is performed only for square TX sizes larger than TX_4X4
5849   if (max_square_tx_size < TX_8X8) return 0;
5850   const int diff_stride = bw;
5851   const struct macroblock_plane *const p = &x->plane[0];
5852   const int16_t *diff = &p->src_diff[0];
5853   init_rd_record_tree(dst_rd_info, bsize);
5854   // Coordinates of the top-left corner of current block within the superblock
5855   // measured in pixels:
5856   const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
5857   const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
5858   int cur_rd_info_idx = 0;
5859   int cur_tx_depth = 0;
5860   TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
5861   while (cur_tx_depth <= MAX_VARTX_DEPTH) {
5862     const int cur_tx_bw = tx_size_wide[cur_tx_size];
5863     const int cur_tx_bh = tx_size_high[cur_tx_size];
5864     if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
5865     const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
5866     const int tx_size_idx = cur_tx_size - TX_8X8;
5867     for (int row = 0; row < bh; row += cur_tx_bh) {
5868       for (int col = 0; col < bw; col += cur_tx_bw) {
5869         if (cur_tx_bw != cur_tx_bh) {
5870           // Use dummy nodes for all rectangular transforms within the
5871           // TX size search tree.
5872           dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
5873         } else {
5874           // Get spatial location of this TX block within the superblock
5875           // (measured in cur_tx_bsize units).
5876           const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
5877           const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;
5878 
5879           int16_t hash_data[MAX_SB_SQUARE];
5880           int16_t *cur_hash_row = hash_data;
5881           const int16_t *cur_diff_row = diff + row * diff_stride + col;
5882           for (int i = 0; i < cur_tx_bh; i++) {
5883             memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
5884             cur_hash_row += cur_tx_bw;
5885             cur_diff_row += diff_stride;
5886           }
5887           const int hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
5888                                                 (uint8_t *)hash_data,
5889                                                 2 * cur_tx_bw * cur_tx_bh);
5890           // Find corresponding RD info based on the hash value.
5891           const int record_idx =
5892               row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
5893           TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
5894           int idx = find_tx_size_rd_info(records, hash);
5895           dst_rd_info[cur_rd_info_idx].rd_info_array =
5896               &records->tx_rd_info[idx];
5897         }
5898         ++cur_rd_info_idx;
5899       }
5900     }
5901     cur_tx_size = next_tx_size;
5902     ++cur_tx_depth;
5903   }
5904   return 1;
5905 }
5906 
5907 // Search for best transform size and type for luma inter blocks.
pick_tx_size_type_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int mi_row,int mi_col,int64_t ref_best_rd)5908 static void pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
5909                                   RD_STATS *rd_stats, BLOCK_SIZE bsize,
5910                                   int mi_row, int mi_col, int64_t ref_best_rd) {
5911   const AV1_COMMON *cm = &cpi->common;
5912   MACROBLOCKD *const xd = &x->e_mbd;
5913   assert(is_inter_block(xd->mi[0]));
5914 
5915   av1_invalid_rd_stats(rd_stats);
5916 
5917   if (cpi->sf.model_based_prune_tx_search_level && ref_best_rd != INT64_MAX) {
5918     int model_rate;
5919     int64_t model_dist;
5920     int model_skip;
5921     model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
5922         cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &model_rate, &model_dist,
5923         &model_skip, NULL, NULL, NULL, NULL);
5924     const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
5925     // If the modeled rd is a lot worse than the best so far, breakout.
5926     // TODO(debargha, urvang): Improve the model and make the check below
5927     // tighter.
5928     assert(cpi->sf.model_based_prune_tx_search_level >= 0 &&
5929            cpi->sf.model_based_prune_tx_search_level <= 2);
5930     static const int prune_factor_by8[] = { 3, 5 };
5931     if (!model_skip &&
5932         ((model_rd *
5933           prune_factor_by8[cpi->sf.model_based_prune_tx_search_level - 1]) >>
5934          3) > ref_best_rd)
5935       return;
5936   }
5937 
5938   uint32_t hash = 0;
5939   int32_t match_index = -1;
5940   MB_RD_RECORD *mb_rd_record = NULL;
5941   const int within_border =
5942       mi_row >= xd->tile.mi_row_start &&
5943       (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
5944       mi_col >= xd->tile.mi_col_start &&
5945       (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
5946   const int is_mb_rd_hash_enabled = (within_border && cpi->sf.use_mb_rd_hash);
5947   const int n4 = bsize_to_num_blk(bsize);
5948   if (is_mb_rd_hash_enabled) {
5949     hash = get_block_residue_hash(x, bsize);
5950     mb_rd_record = &x->mb_rd_record;
5951     match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
5952     if (match_index != -1) {
5953       MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
5954       fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
5955       return;
5956     }
5957   }
5958 
5959   // If we predict that skip is the optimal RD decision - set the respective
5960   // context and terminate early.
5961   int64_t dist;
5962   if (cpi->sf.tx_type_search.use_skip_flag_prediction &&
5963       predict_skip_flag(x, bsize, &dist, cm->reduced_tx_set_used)) {
5964     set_skip_flag(x, rd_stats, bsize, dist);
5965     // Save the RD search results into tx_rd_record.
5966     if (is_mb_rd_hash_enabled)
5967       save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
5968     return;
5969   }
5970 #if CONFIG_SPEED_STATS
5971   ++x->tx_search_count;
5972 #endif  // CONFIG_SPEED_STATS
5973 
5974   // Precompute residual hashes and find existing or add new RD records to
5975   // store and reuse rate and distortion values to speed up TX size search.
5976   TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
5977   int found_rd_info = 0;
5978   if (ref_best_rd != INT64_MAX && within_border && cpi->sf.use_inter_txb_hash) {
5979     found_rd_info =
5980         find_tx_size_rd_records(x, bsize, mi_row, mi_col, matched_rd_info);
5981   }
5982 
5983   // Get the tx_size 1 level down
5984   const TX_SIZE min_tx_size = sub_tx_size_map[max_txsize_rect_lookup[bsize]];
5985   const TxSetType tx_set_type =
5986       av1_get_ext_tx_set_type(min_tx_size, 1, cm->reduced_tx_set_used);
5987   prune_tx(cpi, bsize, x, xd, tx_set_type);
5988 
5989   int found = 0;
5990   RD_STATS this_rd_stats;
5991   av1_init_rd_stats(&this_rd_stats);
5992   const int64_t rd =
5993       select_tx_size_and_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
5994                               found_rd_info ? matched_rd_info : NULL);
5995 
5996   if (rd < INT64_MAX) {
5997     *rd_stats = this_rd_stats;
5998     found = 1;
5999   }
6000 
6001   // Reset the pruning flags.
6002   av1_zero(x->tx_search_prune);
6003   x->tx_split_prune_flag = 0;
6004 
6005   // We should always find at least one candidate unless ref_best_rd is less
6006   // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
6007   // might have failed to find something better)
6008   assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
6009   if (!found) return;
6010 
6011   // Save the RD search results into tx_rd_record.
6012   if (is_mb_rd_hash_enabled) {
6013     assert(mb_rd_record != NULL);
6014     save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
6015   }
6016 }
6017 
model_rd_for_sb_with_fullrdy(const AV1_COMP * const cpi,BLOCK_SIZE bsize,MACROBLOCK * x,MACROBLOCKD * xd,int plane_from,int plane_to,int mi_row,int mi_col,int * out_rate_sum,int64_t * out_dist_sum,int * skip_txfm_sb,int64_t * skip_sse_sb,int * plane_rate,int64_t * plane_sse,int64_t * plane_dist)6018 static void model_rd_for_sb_with_fullrdy(
6019     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
6020     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
6021     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
6022     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
6023   const int ref = xd->mi[0]->ref_frame[0];
6024 
6025   int64_t rate_sum = 0;
6026   int64_t dist_sum = 0;
6027   int64_t total_sse = 0;
6028 
6029   for (int plane = plane_from; plane <= plane_to; ++plane) {
6030     struct macroblock_plane *const p = &x->plane[plane];
6031     struct macroblockd_plane *const pd = &xd->plane[plane];
6032     const BLOCK_SIZE plane_bsize =
6033         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
6034     const int bw = block_size_wide[plane_bsize];
6035     const int bh = block_size_high[plane_bsize];
6036     int64_t sse;
6037     int rate;
6038     int64_t dist;
6039 
6040     if (x->skip_chroma_rd && plane) continue;
6041 
6042     if (is_cur_buf_hbd(xd)) {
6043       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
6044                            pd->dst.stride, bw, bh);
6045     } else {
6046       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
6047                     bh);
6048     }
6049     sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
6050 
6051     RD_STATS rd_stats;
6052     if (plane == 0) {
6053       pick_tx_size_type_yrd(cpi, x, &rd_stats, bsize, mi_row, mi_col,
6054                             INT64_MAX);
6055       if (rd_stats.invalid_rate) {
6056         rate = 0;
6057         dist = sse << 4;
6058       } else {
6059         rate = rd_stats.rate;
6060         dist = rd_stats.dist;
6061       }
6062     } else {
6063       model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
6064                             &dist);
6065     }
6066 
6067     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
6068 
6069     total_sse += sse;
6070     rate_sum += rate;
6071     dist_sum += dist;
6072 
6073     if (plane_rate) plane_rate[plane] = rate;
6074     if (plane_sse) plane_sse[plane] = sse;
6075     if (plane_dist) plane_dist[plane] = dist;
6076   }
6077 
6078   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
6079   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
6080   *out_rate_sum = (int)rate_sum;
6081   *out_dist_sum = dist_sum;
6082 }
6083 
rd_pick_palette_intra_sbuv(const AV1_COMP * const cpi,MACROBLOCK * x,int dc_mode_cost,uint8_t * best_palette_color_map,MB_MODE_INFO * const best_mbmi,int64_t * best_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable)6084 static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
6085                                        int dc_mode_cost,
6086                                        uint8_t *best_palette_color_map,
6087                                        MB_MODE_INFO *const best_mbmi,
6088                                        int64_t *best_rd, int *rate,
6089                                        int *rate_tokenonly, int64_t *distortion,
6090                                        int *skippable) {
6091   MACROBLOCKD *const xd = &x->e_mbd;
6092   MB_MODE_INFO *const mbmi = xd->mi[0];
6093   assert(!is_inter_block(mbmi));
6094   assert(
6095       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type));
6096   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
6097   const BLOCK_SIZE bsize = mbmi->sb_type;
6098   const SequenceHeader *const seq_params = &cpi->common.seq_params;
6099   int this_rate;
6100   int64_t this_rd;
6101   int colors_u, colors_v, colors;
6102   const int src_stride = x->plane[1].src.stride;
6103   const uint8_t *const src_u = x->plane[1].src.buf;
6104   const uint8_t *const src_v = x->plane[2].src.buf;
6105   uint8_t *const color_map = xd->plane[1].color_index_map;
6106   RD_STATS tokenonly_rd_stats;
6107   int plane_block_width, plane_block_height, rows, cols;
6108   av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
6109                            &plane_block_height, &rows, &cols);
6110 
6111   mbmi->uv_mode = UV_DC_PRED;
6112 
6113   int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
6114   if (seq_params->use_highbitdepth) {
6115     colors_u = av1_count_colors_highbd(src_u, src_stride, rows, cols,
6116                                        seq_params->bit_depth, count_buf);
6117     colors_v = av1_count_colors_highbd(src_v, src_stride, rows, cols,
6118                                        seq_params->bit_depth, count_buf);
6119   } else {
6120     colors_u = av1_count_colors(src_u, src_stride, rows, cols, count_buf);
6121     colors_v = av1_count_colors(src_v, src_stride, rows, cols, count_buf);
6122   }
6123 
6124   uint16_t color_cache[2 * PALETTE_MAX_SIZE];
6125   const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
6126 
6127   colors = colors_u > colors_v ? colors_u : colors_v;
6128   if (colors > 1 && colors <= 64) {
6129     int r, c, n, i, j;
6130     const int max_itr = 50;
6131     int lb_u, ub_u, val_u;
6132     int lb_v, ub_v, val_v;
6133     int *const data = x->palette_buffer->kmeans_data_buf;
6134     int centroids[2 * PALETTE_MAX_SIZE];
6135 
6136     uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
6137     uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
6138     if (seq_params->use_highbitdepth) {
6139       lb_u = src_u16[0];
6140       ub_u = src_u16[0];
6141       lb_v = src_v16[0];
6142       ub_v = src_v16[0];
6143     } else {
6144       lb_u = src_u[0];
6145       ub_u = src_u[0];
6146       lb_v = src_v[0];
6147       ub_v = src_v[0];
6148     }
6149 
6150     for (r = 0; r < rows; ++r) {
6151       for (c = 0; c < cols; ++c) {
6152         if (seq_params->use_highbitdepth) {
6153           val_u = src_u16[r * src_stride + c];
6154           val_v = src_v16[r * src_stride + c];
6155           data[(r * cols + c) * 2] = val_u;
6156           data[(r * cols + c) * 2 + 1] = val_v;
6157         } else {
6158           val_u = src_u[r * src_stride + c];
6159           val_v = src_v[r * src_stride + c];
6160           data[(r * cols + c) * 2] = val_u;
6161           data[(r * cols + c) * 2 + 1] = val_v;
6162         }
6163         if (val_u < lb_u)
6164           lb_u = val_u;
6165         else if (val_u > ub_u)
6166           ub_u = val_u;
6167         if (val_v < lb_v)
6168           lb_v = val_v;
6169         else if (val_v > ub_v)
6170           ub_v = val_v;
6171       }
6172     }
6173 
6174     for (n = colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors; n >= 2;
6175          --n) {
6176       for (i = 0; i < n; ++i) {
6177         centroids[i * 2] = lb_u + (2 * i + 1) * (ub_u - lb_u) / n / 2;
6178         centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
6179       }
6180       av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
6181       optimize_palette_colors(color_cache, n_cache, n, 2, centroids);
6182       // Sort the U channel colors in ascending order.
6183       for (i = 0; i < 2 * (n - 1); i += 2) {
6184         int min_idx = i;
6185         int min_val = centroids[i];
6186         for (j = i + 2; j < 2 * n; j += 2)
6187           if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
6188         if (min_idx != i) {
6189           int temp_u = centroids[i], temp_v = centroids[i + 1];
6190           centroids[i] = centroids[min_idx];
6191           centroids[i + 1] = centroids[min_idx + 1];
6192           centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
6193         }
6194       }
6195       av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
6196       extend_palette_color_map(color_map, cols, rows, plane_block_width,
6197                                plane_block_height);
6198       pmi->palette_size[1] = n;
6199       for (i = 1; i < 3; ++i) {
6200         for (j = 0; j < n; ++j) {
6201           if (seq_params->use_highbitdepth)
6202             pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd(
6203                 (int)centroids[j * 2 + i - 1], seq_params->bit_depth);
6204           else
6205             pmi->palette_colors[i * PALETTE_MAX_SIZE + j] =
6206                 clip_pixel((int)centroids[j * 2 + i - 1]);
6207         }
6208       }
6209 
6210       super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
6211       if (tokenonly_rd_stats.rate == INT_MAX) continue;
6212       this_rate = tokenonly_rd_stats.rate +
6213                   intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
6214       this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
6215       if (this_rd < *best_rd) {
6216         *best_rd = this_rd;
6217         *best_mbmi = *mbmi;
6218         memcpy(best_palette_color_map, color_map,
6219                plane_block_width * plane_block_height *
6220                    sizeof(best_palette_color_map[0]));
6221         *rate = this_rate;
6222         *distortion = tokenonly_rd_stats.dist;
6223         *rate_tokenonly = tokenonly_rd_stats.rate;
6224         *skippable = tokenonly_rd_stats.skip;
6225       }
6226     }
6227   }
6228   if (best_mbmi->palette_mode_info.palette_size[1] > 0) {
6229     memcpy(color_map, best_palette_color_map,
6230            plane_block_width * plane_block_height *
6231                sizeof(best_palette_color_map[0]));
6232   }
6233 }
6234 
6235 // Run RD calculation with given chroma intra prediction angle., and return
6236 // the RD cost. Update the best mode info. if the RD cost is the best so far.
pick_intra_angle_routine_sbuv(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int rate_overhead,int64_t best_rd_in,int * rate,RD_STATS * rd_stats,int * best_angle_delta,int64_t * best_rd)6237 static int64_t pick_intra_angle_routine_sbuv(
6238     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
6239     int rate_overhead, int64_t best_rd_in, int *rate, RD_STATS *rd_stats,
6240     int *best_angle_delta, int64_t *best_rd) {
6241   MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
6242   assert(!is_inter_block(mbmi));
6243   int this_rate;
6244   int64_t this_rd;
6245   RD_STATS tokenonly_rd_stats;
6246 
6247   if (!super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd_in))
6248     return INT64_MAX;
6249   this_rate = tokenonly_rd_stats.rate +
6250               intra_mode_info_cost_uv(cpi, x, mbmi, bsize, rate_overhead);
6251   this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
6252   if (this_rd < *best_rd) {
6253     *best_rd = this_rd;
6254     *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
6255     *rate = this_rate;
6256     rd_stats->rate = tokenonly_rd_stats.rate;
6257     rd_stats->dist = tokenonly_rd_stats.dist;
6258     rd_stats->skip = tokenonly_rd_stats.skip;
6259   }
6260   return this_rd;
6261 }
6262 
6263 // With given chroma directional intra prediction mode, pick the best angle
6264 // delta. Return true if a RD cost that is smaller than the input one is found.
static int rd_pick_intra_angle_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    BLOCK_SIZE bsize, int rate_overhead,
                                    int64_t best_rd, int *rate,
                                    RD_STATS *rd_stats) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));

  // rd_cost[2 * delta + s] holds the RD cost of angle delta |delta| with sign
  // s (0: positive, 1: negative). Unsearched slots stay at INT64_MAX.
  int64_t rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
  int best_delta = 0;

  rd_stats->rate = INT_MAX;
  rd_stats->skip = 0;
  rd_stats->dist = INT64_MAX;
  for (int slot = 0; slot < 2 * (MAX_ANGLE_DELTA + 2); ++slot)
    rd_cost[slot] = INT64_MAX;

  // Pass 1: evaluate the even deltas. Delta 0 has only one sign, and its cost
  // is mirrored into both sign slots.
  for (int delta = 0; delta <= MAX_ANGLE_DELTA; delta += 2) {
    const int num_signs = (delta == 0) ? 1 : 2;
    for (int sign = 0; sign < num_signs; ++sign) {
      const int64_t rd_limit =
          (best_rd == INT64_MAX)
              ? INT64_MAX
              : (best_rd + (best_rd >> ((delta == 0) ? 3 : 5)));
      mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * sign) * delta;
      const int64_t cost = pick_intra_angle_routine_sbuv(
          cpi, x, bsize, rate_overhead, rd_limit, rate, rd_stats, &best_delta,
          &best_rd);
      rd_cost[2 * delta + sign] = cost;
      if (delta == 0) {
        if (cost == INT64_MAX) return 0;
        rd_cost[1] = cost;
      }
    }
  }

  assert(best_rd != INT64_MAX);

  // Pass 2: evaluate an odd delta only if at least one even neighbour with the
  // same sign came within ~3% (best_rd >> 5) of the current best.
  for (int delta = 1; delta <= MAX_ANGLE_DELTA; delta += 2) {
    for (int sign = 0; sign < 2; ++sign) {
      const int64_t rd_thresh = best_rd + (best_rd >> 5);
      const int both_neighbours_poor =
          rd_cost[2 * (delta + 1) + sign] > rd_thresh &&
          rd_cost[2 * (delta - 1) + sign] > rd_thresh;
      if (both_neighbours_poor) continue;
      mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * sign) * delta;
      pick_intra_angle_routine_sbuv(cpi, x, bsize, rate_overhead, best_rd,
                                    rate, rd_stats, &best_delta, &best_rd);
    }
  }

  mbmi->angle_delta[PLANE_TYPE_UV] = best_delta;
  return rd_stats->rate != INT_MAX;
}
6319 
// Maps a per-plane sign pair to a joint CfL sign index. For plane ==
// CFL_PRED_U, |a| is weighted by CFL_SIGNS and |b| is added; for the other
// plane the roles of |a| and |b| are swapped. The trailing "- 1" appears to
// skip the invalid (ZERO, ZERO) pair (assumes CFL_SIGN_ZERO == 0; confirm).
// All arguments are parenthesized so expression arguments expand safely.
#define PLANE_SIGN_TO_JOINT_SIGN(plane, a, b)          \
  ((plane) == CFL_PRED_U ? (a) * CFL_SIGNS + (b) - 1 \
                         : (b) * CFL_SIGNS + (a) - 1)
// Searches the chroma-from-luma (CfL) alpha parameters for the current block
// and stores the choice in mbmi->cfl_alpha_idx / mbmi->cfl_alpha_signs.
//
// For each of the two CfL planes it first records the RD cost of alpha zero.
// It then walks the non-zero signs and alphabet indices, keeping the best cost
// and index per (joint sign, plane). A joint sign becomes the overall winner
// only when this plane's cost, plus the other plane's best cost, plus the
// UV_CFL_PRED mode cost, is below |best_rd|. |best_rd| is lowered as winners
// are found.
//
// Returns the rate of signalling the chosen U and V alpha indices. Returns
// INT_MAX if no combination beat the input |best_rd|; in that case alpha
// index 0 and joint sign 0 are stored.
static int cfl_rd_pick_alpha(MACROBLOCK *const x, const AV1_COMP *const cpi,
                             TX_SIZE tx_size, int64_t best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];

  const BLOCK_SIZE bsize = mbmi->sb_type;
#if CONFIG_DEBUG
  assert(is_cfl_allowed(xd) && cpi->oxcf.enable_cfl_intra);
  const int ssx = xd->plane[AOM_PLANE_U].subsampling_x;
  const int ssy = xd->plane[AOM_PLANE_U].subsampling_y;
  const BLOCK_SIZE plane_bsize = get_plane_block_size(mbmi->sb_type, ssx, ssy);
  (void)plane_bsize;
  assert(plane_bsize < BLOCK_SIZES_ALL);
  if (!xd->lossless[mbmi->segment_id]) {
    assert(block_size_wide[plane_bsize] == tx_size_wide[tx_size]);
    assert(block_size_high[plane_bsize] == tx_size_high[tx_size]);
  }
#endif  // CONFIG_DEBUG

  // Turn on xd->cfl's DC-prediction cache for the repeated predictions below.
  // It is turned off and invalidated before returning.
  xd->cfl.use_dc_pred_cache = 1;
  // RD cost of signalling the UV_CFL_PRED mode itself (zero distortion).
  const int64_t mode_rd =
      RDCOST(x->rdmult,
             x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED], 0);
  // Best RD cost and best alphabet index found per (joint sign, plane).
  int64_t best_rd_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
  int best_c[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
#if CONFIG_DEBUG
  int best_rate_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
#endif  // CONFIG_DEBUG

  for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
    RD_STATS rd_stats;
    av1_init_rd_stats(&rd_stats);
    for (int joint_sign = 0; joint_sign < CFL_JOINT_SIGNS; joint_sign++) {
      best_rd_uv[joint_sign][plane] = INT64_MAX;
      best_c[joint_sign][plane] = 0;
    }
    // Collect RD stats for an alpha value of zero in this plane.
    // Skip i == CFL_SIGN_ZERO as (0, 0) is invalid.
    for (int i = CFL_SIGN_NEG; i < CFL_SIGNS; i++) {
      const int joint_sign = PLANE_SIGN_TO_JOINT_SIGN(plane, CFL_SIGN_ZERO, i);
      // The transform is run only once (first sign). Its stats are reused
      // for the remaining signs of the other plane, which only change the
      // alpha rate below.
      if (i == CFL_SIGN_NEG) {
        mbmi->cfl_alpha_idx = 0;
        mbmi->cfl_alpha_signs = joint_sign;
        txfm_rd_in_plane(x, cpi, &rd_stats, best_rd, 0, plane + 1, bsize,
                         tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE, 0);
        if (rd_stats.rate == INT_MAX) break;
      }
      const int alpha_rate = x->cfl_cost[joint_sign][plane][0];
      best_rd_uv[joint_sign][plane] =
          RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
#if CONFIG_DEBUG
      best_rate_uv[joint_sign][plane] = rd_stats.rate;
#endif  // CONFIG_DEBUG
    }
  }

  // -1 means no joint sign has beaten best_rd yet.
  int best_joint_sign = -1;

  // Search non-zero alpha magnitudes for each plane and each non-zero sign.
  for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
    for (int pn_sign = CFL_SIGN_NEG; pn_sign < CFL_SIGNS; pn_sign++) {
      // Each alphabet index that improves at least one joint sign adds 2.
      int progress = 0;
      for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
        int flag = 0;
        RD_STATS rd_stats;
        // Early termination: after the first three indices, stop once the
        // accumulated progress falls behind c.
        if (c > 2 && progress < c) break;
        av1_init_rd_stats(&rd_stats);
        for (int i = 0; i < CFL_SIGNS; i++) {
          const int joint_sign = PLANE_SIGN_TO_JOINT_SIGN(plane, pn_sign, i);
          if (i == 0) {
            // The same index c is written for both planes. Only chroma plane
            // (plane + 1) is transformed here, and its stats are reused for
            // every sign i of the other plane.
            mbmi->cfl_alpha_idx = (c << CFL_ALPHABET_SIZE_LOG2) + c;
            mbmi->cfl_alpha_signs = joint_sign;
            txfm_rd_in_plane(x, cpi, &rd_stats, best_rd, 0, plane + 1, bsize,
                             tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE,
                             0);
            if (rd_stats.rate == INT_MAX) break;
          }
          const int alpha_rate = x->cfl_cost[joint_sign][plane][c];
          int64_t this_rd =
              RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
          if (this_rd >= best_rd_uv[joint_sign][plane]) continue;
          best_rd_uv[joint_sign][plane] = this_rd;
          best_c[joint_sign][plane] = c;
#if CONFIG_DEBUG
          best_rate_uv[joint_sign][plane] = rd_stats.rate;
#endif  // CONFIG_DEBUG
          flag = 2;
          // The other plane has no valid cost for this joint sign yet.
          if (best_rd_uv[joint_sign][!plane] == INT64_MAX) continue;
          // Combined cost of both planes plus the CfL mode cost.
          this_rd += mode_rd + best_rd_uv[joint_sign][!plane];
          if (this_rd >= best_rd) continue;
          best_rd = this_rd;
          best_joint_sign = joint_sign;
        }
        progress += flag;
      }
    }
  }

  int best_rate_overhead = INT_MAX;
  int ind = 0;
  if (best_joint_sign >= 0) {
    // Pack the per-plane indices (U in the high bits) and sum their rates.
    const int u = best_c[best_joint_sign][CFL_PRED_U];
    const int v = best_c[best_joint_sign][CFL_PRED_V];
    ind = (u << CFL_ALPHABET_SIZE_LOG2) + v;
    best_rate_overhead = x->cfl_cost[best_joint_sign][CFL_PRED_U][u] +
                         x->cfl_cost[best_joint_sign][CFL_PRED_V][v];
#if CONFIG_DEBUG
    xd->cfl.rate = x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED] +
                   best_rate_overhead +
                   best_rate_uv[best_joint_sign][CFL_PRED_U] +
                   best_rate_uv[best_joint_sign][CFL_PRED_V];
#endif  // CONFIG_DEBUG
  } else {
    best_joint_sign = 0;
  }

  mbmi->cfl_alpha_idx = ind;
  mbmi->cfl_alpha_signs = best_joint_sign;
  // Disable and invalidate the DC-prediction cache enabled above.
  xd->cfl.use_dc_pred_cache = 0;
  xd->cfl.dc_pred_is_cached[0] = 0;
  xd->cfl.dc_pred_is_cached[1] = 0;
  return best_rate_overhead;
}
6444 
// Resets the chroma prediction fields of |mbmi| before a chroma mode search:
// no UV palette (palette_size[1] == 0) and UV_DC_PRED as the chroma mode.
static void init_sbuv_mode(MB_MODE_INFO *const mbmi) {
  mbmi->palette_mode_info.palette_size[1] = 0;
  mbmi->uv_mode = UV_DC_PRED;
}
6449 
// Searches the chroma (UV) intra prediction modes permitted by the speed
// features and encoder configuration. This includes CfL alpha selection,
// directional angle deltas and, when allowed, palette. The winning mode info
// is written back to xd->mi[0]; its rate, token-only rate, distortion and
// skip flag go to the output pointers. Returns the best RD cost and asserts
// that some mode was chosen.
static int64_t rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       int *rate, int *rate_tokenonly,
                                       int64_t *distortion, int *skippable,
                                       BLOCK_SIZE bsize, TX_SIZE max_tx_size) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  MB_MODE_INFO best_mbmi = *mbmi;
  int64_t best_rd = INT64_MAX;

  for (int order_idx = 0; order_idx < UV_INTRA_MODES; ++order_idx) {
    const UV_PREDICTION_MODE uv_mode = uv_rd_search_mode_order[order_idx];
    const int directional = av1_is_directional_mode(get_uv_mode(uv_mode));

    // Skip modes excluded by the speed-feature mask or disabled by config.
    const int masked_out =
        !(cpi->sf.intra_uv_mode_mask[txsize_sqr_up_map[max_tx_size]] &
          (1 << uv_mode));
    const int smooth_disabled = !cpi->oxcf.enable_smooth_intra &&
                                uv_mode >= UV_SMOOTH_PRED &&
                                uv_mode <= UV_SMOOTH_H_PRED;
    const int paeth_disabled =
        !cpi->oxcf.enable_paeth_intra && uv_mode == UV_PAETH_PRED;
    if (masked_out || smooth_disabled || paeth_disabled) continue;

    mbmi->uv_mode = uv_mode;

    // CfL: pick alpha first; its signalling rate joins the mode cost.
    int alpha_rate = 0;
    if (uv_mode == UV_CFL_PRED) {
      if (!is_cfl_allowed(xd) || !cpi->oxcf.enable_cfl_intra) continue;
      assert(!directional);
      alpha_rate = cfl_rd_pick_alpha(
          x, cpi, av1_get_tx_size(AOM_PLANE_U, xd), best_rd);
      if (alpha_rate == INT_MAX) continue;
    }

    mbmi->angle_delta[PLANE_TYPE_UV] = 0;

    // Token-only RD: either search the angle delta or evaluate delta 0.
    RD_STATS tokenonly_rd_stats;
    int this_rate;
    int stats_valid;
    if (directional && av1_use_angle_delta(mbmi->sb_type) &&
        cpi->oxcf.enable_angle_delta) {
      const int angle_rate_overhead =
          x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][uv_mode];
      stats_valid = rd_pick_intra_angle_sbuv(cpi, x, bsize, angle_rate_overhead,
                                             best_rd, &this_rate,
                                             &tokenonly_rd_stats);
    } else {
      stats_valid =
          super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd);
    }
    if (!stats_valid) continue;

    const int mode_cost =
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][uv_mode] +
        alpha_rate;
    this_rate = tokenonly_rd_stats.rate +
                intra_mode_info_cost_uv(cpi, x, mbmi, bsize, mode_cost);
    if (uv_mode == UV_CFL_PRED) {
      assert(is_cfl_allowed(xd) && cpi->oxcf.enable_cfl_intra);
#if CONFIG_DEBUG
      if (!xd->lossless[mbmi->segment_id])
        assert(xd->cfl.rate == tokenonly_rd_stats.rate + mode_cost);
#endif  // CONFIG_DEBUG
    }

    const int64_t this_rd =
        RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
    if (this_rd >= best_rd) continue;

    best_mbmi = *mbmi;
    best_rd = this_rd;
    *rate = this_rate;
    *rate_tokenonly = tokenonly_rd_stats.rate;
    *distortion = tokenonly_rd_stats.dist;
    *skippable = tokenonly_rd_stats.skip;
  }

  // Palette is tried last and may replace best_mbmi / best_rd in place.
  if (cpi->oxcf.enable_palette &&
      av1_allow_palette(cpi->common.allow_screen_content_tools,
                        mbmi->sb_type)) {
    rd_pick_palette_intra_sbuv(
        cpi, x,
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][UV_DC_PRED],
        x->palette_buffer->best_palette_color_map, &best_mbmi, &best_rd, rate,
        rate_tokenonly, distortion, skippable);
  }

  *mbmi = best_mbmi;
  // Make sure we actually chose a mode
  assert(best_rd < INT64_MAX);
  return best_rd;
}
6537 
// Chooses the chroma intra prediction mode for the current block.
// When x->skip_chroma_rd is set, it reports a zero-rate, zero-distortion,
// skippable UV_DC_PRED result without searching. Otherwise it records whether
// this block is a chroma reference. If store_cfl_required_rdo() asks for it,
// it re-encodes the luma plane so the reconstructed luma is stored for CfL.
// It then runs the full chroma mode search on the chroma-scaled block size.
static void choose_intra_uv_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
                                 BLOCK_SIZE bsize, TX_SIZE max_tx_size,
                                 int *rate_uv, int *rate_uv_tokenonly,
                                 int64_t *dist_uv, int *skip_uv,
                                 UV_PREDICTION_MODE *mode_uv) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  init_sbuv_mode(mbmi);

  if (x->skip_chroma_rd) {
    *rate_uv = 0;
    *rate_uv_tokenonly = 0;
    *dist_uv = 0;
    *skip_uv = 1;
    *mode_uv = UV_DC_PRED;
    return;
  }

  const AV1_COMMON *const cm = &cpi->common;
  // Recover the block's mode-info row/column from its distance to the frame's
  // top/left edges.
  const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
  const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);

  xd->cfl.is_chroma_reference =
      is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x,
                          cm->seq_params.subsampling_y);
  const BLOCK_SIZE uv_bsize =
      scale_chroma_bsize(bsize, xd->plane[AOM_PLANE_U].subsampling_x,
                         xd->plane[AOM_PLANE_U].subsampling_y);

  // Reconstructed luma is only stored here when chroma RDO needs it; otherwise
  // encode_superblock() stores it later.
  xd->cfl.store_y = store_cfl_required_rdo(cm, x);
  if (xd->cfl.store_y) {
    av1_encode_intra_block_plane(cpi, x, mbmi->sb_type, AOM_PLANE_Y,
                                 cpi->optimize_seg_arr[mbmi->segment_id],
                                 mi_row, mi_col);
    xd->cfl.store_y = 0;
  }

  rd_pick_intra_sbuv_mode(cpi, x, rate_uv, rate_uv_tokenonly, dist_uv, skip_uv,
                          uv_bsize, max_tx_size);
  *mode_uv = mbmi->uv_mode;
}
6578 
// Rate of signalling inter mode |mode| under the packed |mode_context|.
// Compound modes use a single table lookup. Single-reference modes walk the
// NEWMV -> GLOBALMV -> REFMV decision chain, summing one cost per decision.
// Each decision's context is extracted from its own bit field of
// |mode_context|.
static int cost_mv_ref(const MACROBLOCK *const x, PREDICTION_MODE mode,
                       int16_t mode_context) {
  if (is_inter_compound_mode(mode))
    return x
        ->inter_compound_mode_cost[mode_context][INTER_COMPOUND_OFFSET(mode)];

  assert(is_inter_mode(mode));

  // Decision 1: NEWMV or not.
  const int16_t newmv_ctx = mode_context & NEWMV_CTX_MASK;
  if (mode == NEWMV) return x->newmv_mode_cost[newmv_ctx][0];
  int cost = x->newmv_mode_cost[newmv_ctx][1];

  // Decision 2: GLOBALMV or not.
  const int16_t globalmv_ctx =
      (mode_context >> GLOBALMV_OFFSET) & GLOBALMV_CTX_MASK;
  if (mode == GLOBALMV) return cost + x->zeromv_mode_cost[globalmv_ctx][0];
  cost += x->zeromv_mode_cost[globalmv_ctx][1];

  // Decision 3: NEARESTMV vs. the other reference-MV modes.
  const int16_t refmv_ctx = (mode_context >> REFMV_OFFSET) & REFMV_CTX_MASK;
  return cost + x->refmv_mode_cost[refmv_ctx][mode != NEARESTMV];
}
6609 
get_interinter_compound_mask_rate(const MACROBLOCK * const x,const MB_MODE_INFO * const mbmi)6610 static int get_interinter_compound_mask_rate(const MACROBLOCK *const x,
6611                                              const MB_MODE_INFO *const mbmi) {
6612   switch (mbmi->interinter_comp.type) {
6613     case COMPOUND_AVERAGE: return 0;
6614     case COMPOUND_WEDGE:
6615       return get_interinter_wedge_bits(mbmi->sb_type) > 0
6616                  ? av1_cost_literal(1) +
6617                        x->wedge_idx_cost[mbmi->sb_type]
6618                                         [mbmi->interinter_comp.wedge_index]
6619                  : 0;
6620     case COMPOUND_DIFFWTD: return av1_cost_literal(1);
6621     default: assert(0); return 0;
6622   }
6623 }
6624 
mv_check_bounds(const MvLimits * mv_limits,const MV * mv)6625 static INLINE int mv_check_bounds(const MvLimits *mv_limits, const MV *mv) {
6626   return (mv->row >> 3) < mv_limits->row_min ||
6627          (mv->row >> 3) > mv_limits->row_max ||
6628          (mv->col >> 3) < mv_limits->col_min ||
6629          (mv->col >> 3) > mv_limits->col_max;
6630 }
6631 
get_single_mode(PREDICTION_MODE this_mode,int ref_idx,int is_comp_pred)6632 static INLINE PREDICTION_MODE get_single_mode(PREDICTION_MODE this_mode,
6633                                               int ref_idx, int is_comp_pred) {
6634   PREDICTION_MODE single_mode;
6635   if (is_comp_pred) {
6636     single_mode =
6637         ref_idx ? compound_ref1_mode(this_mode) : compound_ref0_mode(this_mode);
6638   } else {
6639     single_mode = this_mode;
6640   }
6641   return single_mode;
6642 }
6643 
// Iterative joint motion search for the two references of a compound block.
//
// Iterations alternate between the references: even iterations refine
// cur_mv[0] and odd iterations refine cur_mv[1]. Each iteration builds the
// prediction from the *other* reference using its current MV. It then runs a
// small-range full-pel search followed (unless integer MVs are forced) by a
// sub-pel search of the current reference against that fixed second
// predictor, optionally weighted by |mask|.
//
// The loop runs at most 4 iterations and stops early when:
//   - an iteration does not lower the error recorded for its reference, or
//   - from the third iteration on, the other reference's MV still equals its
//     initial value and the current reference's MV has not changed, at least
//     not in its full-pel part.
//
// On return, cur_mv holds the refined MVs and *rate_mv holds the summed MV
// signalling cost of both references. |ref_mv_sub8x8| is unused. |block|
// selects a 4x4 sub-block offset (ic, ir) applied to the prediction position.
static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
                                BLOCK_SIZE bsize, int_mv *cur_mv, int mi_row,
                                int mi_col, int_mv *ref_mv_sub8x8[2],
                                const uint8_t *mask, int mask_stride,
                                int *rate_mv, const int block) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const int pw = block_size_wide[bsize];
  const int ph = block_size_high[bsize];
  // Only the luma plane is searched.
  const int plane = 0;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  // This function should only ever be called for compound modes
  assert(has_second_ref(mbmi));
  // Starting MVs, used by the early-termination check in the loop.
  const int_mv init_mv[2] = { cur_mv[0], cur_mv[1] };
  const int refs[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
  int_mv ref_mv[2];
  int ite, ref;
  // ic and ir are the 4x4 coordinates of the sub8x8 at index "block"
  const int ic = block & 1;
  const int ir = (block - ic) >> 1;
  struct macroblockd_plane *const pd = &xd->plane[0];
  const int p_col = ((mi_col * MI_SIZE) >> pd->subsampling_x) + 4 * ic;
  const int p_row = ((mi_row * MI_SIZE) >> pd->subsampling_y) + 4 * ir;

  ConvolveParams conv_params = get_conv_params(0, plane, xd->bd);
  conv_params.use_dist_wtd_comp_avg = 0;
  // Per-reference warp permissions used when building the second predictor.
  WarpTypesAllowed warp_types[2];
  for (ref = 0; ref < 2; ++ref) {
    const WarpedMotionParams *const wm =
        &xd->global_motion[xd->mi[0]->ref_frame[ref]];
    const int is_global = is_global_mv_block(xd->mi[0], wm->wmtype);
    warp_types[ref].global_warp_allowed = is_global;
    warp_types[ref].local_warp_allowed = mbmi->motion_mode == WARPED_CAUSAL;
  }

  // Do joint motion search in compound mode to get more accurate mv.
  struct buf_2d backup_yv12[2][MAX_MB_PLANE];
  // Best error reached so far per reference; failing to beat it ends the loop.
  int last_besterr[2] = { INT_MAX, INT_MAX };
  // Scaled versions of the references, or NULL when none is used.
  const YV12_BUFFER_CONFIG *const scaled_ref_frame[2] = {
    av1_get_scaled_ref_frame(cpi, refs[0]),
    av1_get_scaled_ref_frame(cpi, refs[1])
  };

  // Prediction buffer from second frame.
  // Sized for 16-bit samples; get_buf_by_bd() presumably returns the view that
  // matches the current bit depth.
  DECLARE_ALIGNED(16, uint8_t, second_pred16[MAX_SB_SQUARE * sizeof(uint16_t)]);
  uint8_t *second_pred = get_buf_by_bd(xd, second_pred16);
  (void)ref_mv_sub8x8;

  // Search results are written to x->best_mv by the search helpers.
  MV *const best_mv = &x->best_mv.as_mv;
  const int search_range = SEARCH_RANGE_8P;
  const int sadpb = x->sadperbit16;
  // Allow joint search multiple times iteratively for each reference frame
  // and break out of the search loop if it couldn't find a better mv.
  for (ite = 0; ite < 4; ite++) {
    struct buf_2d ref_yv12[2];
    int bestsme = INT_MAX;
    // Saved so the search-range narrowing below can be undone.
    MvLimits tmp_mv_limits = x->mv_limits;
    int id = ite % 2;  // Even iterations search in the first reference frame,
                       // odd iterations search in the second. The predictor
                       // found for the 'other' reference frame is factored in.
    // From the third iteration on, stop if the other MV never moved and this
    // MV has not moved (exactly, or in its full-pel part) since the start.
    if (ite >= 2 && cur_mv[!id].as_int == init_mv[!id].as_int) {
      if (cur_mv[id].as_int == init_mv[id].as_int) {
        break;
      } else {
        int_mv cur_int_mv, init_int_mv;
        cur_int_mv.as_mv.col = cur_mv[id].as_mv.col >> 3;
        cur_int_mv.as_mv.row = cur_mv[id].as_mv.row >> 3;
        init_int_mv.as_mv.row = init_mv[id].as_mv.row >> 3;
        init_int_mv.as_mv.col = init_mv[id].as_mv.col >> 3;
        if (cur_int_mv.as_int == init_int_mv.as_int) {
          break;
        }
      }
    }
    for (ref = 0; ref < 2; ++ref) {
      ref_mv[ref] = av1_get_ref_mv(x, ref);
      // Swap out the reference frame for a version that's been scaled to
      // match the resolution of the current frame, allowing the existing
      // motion search code to be used without additional modifications.
      if (scaled_ref_frame[ref]) {
        int i;
        for (i = 0; i < num_planes; i++)
          backup_yv12[ref][i] = xd->plane[i].pre[ref];
        av1_setup_pre_planes(xd, ref, scaled_ref_frame[ref], mi_row, mi_col,
                             NULL, num_planes);
      }
    }

    assert(IMPLIES(scaled_ref_frame[0] != NULL,
                   cm->width == scaled_ref_frame[0]->y_crop_width &&
                       cm->height == scaled_ref_frame[0]->y_crop_height));
    assert(IMPLIES(scaled_ref_frame[1] != NULL,
                   cm->width == scaled_ref_frame[1]->y_crop_width &&
                       cm->height == scaled_ref_frame[1]->y_crop_height));

    // Initialize based on (possibly scaled) prediction buffers.
    ref_yv12[0] = xd->plane[plane].pre[0];
    ref_yv12[1] = xd->plane[plane].pre[1];

    // Get the prediction block from the 'other' reference frame.
    const InterpFilters interp_filters = EIGHTTAP_REGULAR;

    // Since we have scaled the reference frames to match the size of the
    // current frame we must use a unit scaling factor during mode selection.
    av1_build_inter_predictor(ref_yv12[!id].buf, ref_yv12[!id].stride,
                              second_pred, pw, &cur_mv[!id].as_mv,
                              &cm->sf_identity, pw, ph, &conv_params,
                              interp_filters, &warp_types[!id], p_col, p_row,
                              plane, !id, MV_PRECISION_Q3, mi_col * MI_SIZE,
                              mi_row * MI_SIZE, xd, cm->allow_warped_motion);

    // Set up xd->jcp_param (compound weights) for the current search order.
    const int order_idx = id != 0;
    av1_dist_wtd_comp_weight_assign(
        cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
        &xd->jcp_param.bck_offset, &xd->jcp_param.use_dist_wtd_comp_avg, 1);

    // Do full-pixel compound motion search on the current reference frame.
    if (id) xd->plane[plane].pre[0] = ref_yv12[id];
    av1_set_mv_search_range(&x->mv_limits, &ref_mv[id].as_mv);

    // Use the mv result from the single mode as mv predictor.
    *best_mv = cur_mv[id].as_mv;

    // Convert the start MV to full-pel units for the full-pel search.
    best_mv->col >>= 3;
    best_mv->row >>= 3;

    // Small-range full-pixel motion search.
    bestsme = av1_refining_search_8p_c(x, sadpb, search_range,
                                       &cpi->fn_ptr[bsize], mask, mask_stride,
                                       id, &ref_mv[id].as_mv, second_pred);
    // Re-score the full-pel winner with the masked or plain compound variance.
    if (bestsme < INT_MAX) {
      if (mask)
        bestsme = av1_get_mvpred_mask_var(x, best_mv, &ref_mv[id].as_mv,
                                          second_pred, mask, mask_stride, id,
                                          &cpi->fn_ptr[bsize], 1);
      else
        bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
                                        second_pred, &cpi->fn_ptr[bsize], 1);
    }

    x->mv_limits = tmp_mv_limits;

    // Restore the pointer to the first (possibly scaled) prediction buffer.
    if (id) xd->plane[plane].pre[0] = ref_yv12[0];

    for (ref = 0; ref < 2; ++ref) {
      if (scaled_ref_frame[ref]) {
        // Swap back the original buffers for subpel motion search.
        for (int i = 0; i < num_planes; i++) {
          xd->plane[i].pre[ref] = backup_yv12[ref][i];
        }
        // Re-initialize based on unscaled prediction buffers.
        ref_yv12[ref] = xd->plane[plane].pre[ref];
      }
    }

    // Do sub-pixel compound motion search on the current reference frame.
    if (id) xd->plane[plane].pre[0] = ref_yv12[id];

    // With integer-only MVs the sub-pel search below is skipped, so convert
    // the full-pel result back to the units used for cur_mv (x8).
    if (cpi->common.cur_frame_force_integer_mv) {
      x->best_mv.as_mv.row *= 8;
      x->best_mv.as_mv.col *= 8;
    }
    if (bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0) {
      int dis; /* TODO: use dis in distortion calculation later. */
      unsigned int sse;
      bestsme = cpi->find_fractional_mv_step(
          x, cm, mi_row, mi_col, &ref_mv[id].as_mv,
          cpi->common.allow_high_precision_mv, x->errorperbit,
          &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
          x->nmv_vec_cost, x->mv_cost_stack, &dis, &sse, second_pred, mask,
          mask_stride, id, pw, ph, cpi->sf.use_accurate_subpel_search, 1);
    }

    // Restore the pointer to the first prediction buffer.
    if (id) xd->plane[plane].pre[0] = ref_yv12[0];
    // Keep the new MV only if it lowered this reference's error; otherwise
    // stop iterating.
    if (bestsme < last_besterr[id]) {
      cur_mv[id].as_mv = *best_mv;
      last_besterr[id] = bestsme;
    } else {
      break;
    }
  }

  // Total cost of coding both (refined) MVs against their reference MVs.
  *rate_mv = 0;

  for (ref = 0; ref < 2; ++ref) {
    const int_mv curr_ref_mv = av1_get_ref_mv(x, ref);
    *rate_mv +=
        av1_mv_bit_cost(&cur_mv[ref].as_mv, &curr_ref_mv.as_mv, x->nmv_vec_cost,
                        x->mv_cost_stack, MV_COST_WEIGHT);
  }
}
6838 
estimate_ref_frame_costs(const AV1_COMMON * cm,const MACROBLOCKD * xd,const MACROBLOCK * x,int segment_id,unsigned int * ref_costs_single,unsigned int (* ref_costs_comp)[REF_FRAMES])6839 static void estimate_ref_frame_costs(
6840     const AV1_COMMON *cm, const MACROBLOCKD *xd, const MACROBLOCK *x,
6841     int segment_id, unsigned int *ref_costs_single,
6842     unsigned int (*ref_costs_comp)[REF_FRAMES]) {
6843   int seg_ref_active =
6844       segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME);
6845   if (seg_ref_active) {
6846     memset(ref_costs_single, 0, REF_FRAMES * sizeof(*ref_costs_single));
6847     int ref_frame;
6848     for (ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame)
6849       memset(ref_costs_comp[ref_frame], 0,
6850              REF_FRAMES * sizeof((*ref_costs_comp)[0]));
6851   } else {
6852     int intra_inter_ctx = av1_get_intra_inter_context(xd);
6853     ref_costs_single[INTRA_FRAME] = x->intra_inter_cost[intra_inter_ctx][0];
6854     unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
6855 
6856     for (int i = LAST_FRAME; i <= ALTREF_FRAME; ++i)
6857       ref_costs_single[i] = base_cost;
6858 
6859     const int ctx_p1 = av1_get_pred_context_single_ref_p1(xd);
6860     const int ctx_p2 = av1_get_pred_context_single_ref_p2(xd);
6861     const int ctx_p3 = av1_get_pred_context_single_ref_p3(xd);
6862     const int ctx_p4 = av1_get_pred_context_single_ref_p4(xd);
6863     const int ctx_p5 = av1_get_pred_context_single_ref_p5(xd);
6864     const int ctx_p6 = av1_get_pred_context_single_ref_p6(xd);
6865 
6866     // Determine cost of a single ref frame, where frame types are represented
6867     // by a tree:
6868     // Level 0: add cost whether this ref is a forward or backward ref
6869     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6870     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6871     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6872     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6873     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6874     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6875     ref_costs_single[ALTREF_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6876 
6877     // Level 1: if this ref is forward ref,
6878     // add cost whether it is last/last2 or last3/golden
6879     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p3][2][0];
6880     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p3][2][0];
6881     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p3][2][1];
6882     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p3][2][1];
6883 
6884     // Level 1: if this ref is backward ref
6885     // then add cost whether this ref is altref or backward ref
6886     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p2][1][0];
6887     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p2][1][0];
6888     ref_costs_single[ALTREF_FRAME] += x->single_ref_cost[ctx_p2][1][1];
6889 
6890     // Level 2: further add cost whether this ref is last or last2
6891     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p4][3][0];
6892     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p4][3][1];
6893 
6894     // Level 2: last3 or golden
6895     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p5][4][0];
6896     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p5][4][1];
6897 
6898     // Level 2: bwdref or altref2
6899     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p6][5][0];
6900     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p6][5][1];
6901 
6902     if (cm->current_frame.reference_mode != SINGLE_REFERENCE) {
6903       // Similar to single ref, determine cost of compound ref frames.
6904       // cost_compound_refs = cost_first_ref + cost_second_ref
6905       const int bwdref_comp_ctx_p = av1_get_pred_context_comp_bwdref_p(xd);
6906       const int bwdref_comp_ctx_p1 = av1_get_pred_context_comp_bwdref_p1(xd);
6907       const int ref_comp_ctx_p = av1_get_pred_context_comp_ref_p(xd);
6908       const int ref_comp_ctx_p1 = av1_get_pred_context_comp_ref_p1(xd);
6909       const int ref_comp_ctx_p2 = av1_get_pred_context_comp_ref_p2(xd);
6910 
6911       const int comp_ref_type_ctx = av1_get_comp_reference_type_context(xd);
6912       unsigned int ref_bicomp_costs[REF_FRAMES] = { 0 };
6913 
6914       ref_bicomp_costs[LAST_FRAME] = ref_bicomp_costs[LAST2_FRAME] =
6915           ref_bicomp_costs[LAST3_FRAME] = ref_bicomp_costs[GOLDEN_FRAME] =
6916               base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][1];
6917       ref_bicomp_costs[BWDREF_FRAME] = ref_bicomp_costs[ALTREF2_FRAME] = 0;
6918       ref_bicomp_costs[ALTREF_FRAME] = 0;
6919 
6920       // cost of first ref frame
6921       ref_bicomp_costs[LAST_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][0];
6922       ref_bicomp_costs[LAST2_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][0];
6923       ref_bicomp_costs[LAST3_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][1];
6924       ref_bicomp_costs[GOLDEN_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][1];
6925 
6926       ref_bicomp_costs[LAST_FRAME] += x->comp_ref_cost[ref_comp_ctx_p1][1][0];
6927       ref_bicomp_costs[LAST2_FRAME] += x->comp_ref_cost[ref_comp_ctx_p1][1][1];
6928 
6929       ref_bicomp_costs[LAST3_FRAME] += x->comp_ref_cost[ref_comp_ctx_p2][2][0];
6930       ref_bicomp_costs[GOLDEN_FRAME] += x->comp_ref_cost[ref_comp_ctx_p2][2][1];
6931 
6932       // cost of second ref frame
6933       ref_bicomp_costs[BWDREF_FRAME] +=
6934           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][0];
6935       ref_bicomp_costs[ALTREF2_FRAME] +=
6936           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][0];
6937       ref_bicomp_costs[ALTREF_FRAME] +=
6938           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][1];
6939 
6940       ref_bicomp_costs[BWDREF_FRAME] +=
6941           x->comp_bwdref_cost[bwdref_comp_ctx_p1][1][0];
6942       ref_bicomp_costs[ALTREF2_FRAME] +=
6943           x->comp_bwdref_cost[bwdref_comp_ctx_p1][1][1];
6944 
6945       // cost: if one ref frame is forward ref, the other ref is backward ref
6946       int ref0, ref1;
6947       for (ref0 = LAST_FRAME; ref0 <= GOLDEN_FRAME; ++ref0) {
6948         for (ref1 = BWDREF_FRAME; ref1 <= ALTREF_FRAME; ++ref1) {
6949           ref_costs_comp[ref0][ref1] =
6950               ref_bicomp_costs[ref0] + ref_bicomp_costs[ref1];
6951         }
6952       }
6953 
6954       // cost: if both ref frames are the same side.
6955       const int uni_comp_ref_ctx_p = av1_get_pred_context_uni_comp_ref_p(xd);
6956       const int uni_comp_ref_ctx_p1 = av1_get_pred_context_uni_comp_ref_p1(xd);
6957       const int uni_comp_ref_ctx_p2 = av1_get_pred_context_uni_comp_ref_p2(xd);
6958       ref_costs_comp[LAST_FRAME][LAST2_FRAME] =
6959           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6960           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6961           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][0];
6962       ref_costs_comp[LAST_FRAME][LAST3_FRAME] =
6963           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6964           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6965           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][1] +
6966           x->uni_comp_ref_cost[uni_comp_ref_ctx_p2][2][0];
6967       ref_costs_comp[LAST_FRAME][GOLDEN_FRAME] =
6968           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6969           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6970           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][1] +
6971           x->uni_comp_ref_cost[uni_comp_ref_ctx_p2][2][1];
6972       ref_costs_comp[BWDREF_FRAME][ALTREF_FRAME] =
6973           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6974           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][1];
6975     } else {
6976       int ref0, ref1;
6977       for (ref0 = LAST_FRAME; ref0 <= GOLDEN_FRAME; ++ref0) {
6978         for (ref1 = BWDREF_FRAME; ref1 <= ALTREF_FRAME; ++ref1)
6979           ref_costs_comp[ref0][ref1] = 512;
6980       }
6981       ref_costs_comp[LAST_FRAME][LAST2_FRAME] = 512;
6982       ref_costs_comp[LAST_FRAME][LAST3_FRAME] = 512;
6983       ref_costs_comp[LAST_FRAME][GOLDEN_FRAME] = 512;
6984       ref_costs_comp[BWDREF_FRAME][ALTREF_FRAME] = 512;
6985     }
6986   }
6987 }
6988 
// Snapshot the currently selected coding decision of the block into ctx so it
// can be restored later if the encoder decides to encode the block this way.
//
// Copies the skip flags, the chosen mode index, the block's MB_MODE_INFO and
// MB_MODE_INFO_EXT, and the three per-reference-mode prediction differences
// (narrowed from int64_t to int).
static void store_coding_context(MACROBLOCK *x, PICK_MODE_CONTEXT *ctx,
                                 int mode_index,
                                 int64_t comp_pred_diff[REFERENCE_MODES],
                                 int skippable) {
  const MB_MODE_INFO *const chosen_mi = x->e_mbd.mi[0];

  ctx->best_mode_index = mode_index;
  ctx->skippable = skippable;
  ctx->skip = x->skip;

  // Full copies of the mode info so later searches cannot disturb them.
  ctx->mic = *chosen_mi;
  ctx->mbmi_ext = *x->mbmi_ext;

  // Per-reference-mode differences, stored as int in the context.
  ctx->single_pred_diff = (int)comp_pred_diff[SINGLE_REFERENCE];
  ctx->comp_pred_diff = (int)comp_pred_diff[COMPOUND_REFERENCE];
  ctx->hybrid_pred_diff = (int)comp_pred_diff[REFERENCE_MODE_SELECT];
}
7006 
// Prepare the reference-MV candidates for ref_frame before motion search.
//
// Points yv12_mb[ref_frame] at this block's reference planes, builds the
// candidate MV list into x->mbmi_ext via av1_find_mv_refs(), then lets
// av1_mv_pred() test the leading candidates to choose the center point for
// subsequent searches. av1_mv_pred() does not support scaling, so while it
// runs the pre-scaled copy of the reference is used when one exists, and
// yv12_mb[ref_frame] is pointed back at the unscaled reference afterwards.
static void setup_buffer_ref_mvs_inter(
    const AV1_COMP *const cpi, MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame,
    BLOCK_SIZE block_size, int mi_row, int mi_col,
    struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const YV12_BUFFER_CONFIG *scaled_ref_frame =
      av1_get_scaled_ref_frame(cpi, ref_frame);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  MB_MODE_INFO_EXT *const ext = x->mbmi_ext;
  const struct scale_factors *const sf =
      get_ref_scale_factors_const(cm, ref_frame);
  const YV12_BUFFER_CONFIG *yv12 = get_ref_frame_yv12_buf(cm, ref_frame);
  assert(yv12 != NULL);

  struct buf_2d *const ref_planes = yv12_mb[ref_frame];
  const int use_scaled = scaled_ref_frame != NULL;

  // Either the pre-scaled copy (no scale factors needed) or the original
  // reference together with its own scale factors.
  const YV12_BUFFER_CONFIG *const search_src =
      use_scaled ? scaled_ref_frame : yv12;
  const struct scale_factors *const search_sf = use_scaled ? NULL : sf;
  av1_setup_pred_block(xd, ref_planes, search_src, mi_row, mi_col, search_sf,
                       search_sf, num_planes);

  // Initial, ordered list of candidate vectors from the neighbourhood.
  av1_find_mv_refs(cm, xd, mbmi, ref_frame, ext->ref_mv_count,
                   ext->ref_mv_stack, NULL, ext->global_mvs, mi_row, mi_col,
                   ext->mode_context);

  // Encoder-side refinement: evaluate the top candidates in full.
  av1_mv_pred(cpi, x, ref_planes[0].buf, ref_planes[0].stride, ref_frame,
              block_size);

  // Hand the unscaled reference back to later users of yv12_mb.
  if (use_scaled) {
    av1_setup_pred_block(xd, ref_planes, yv12, mi_row, mi_col, sf, sf,
                         num_planes);
  }
}
7052 
// Motion search of the current block against a single reference,
// mbmi->ref_frame[ref_idx]: a full-pixel search, followed by a sub-pixel
// refinement unless the frame forces integer MVs.
//
// Outputs:
//   x->best_mv      - the selected motion vector.
//   *rate_mv        - bit cost of x->best_mv relative to the reference MV.
//   x->pred_mv[ref] - set to the result when adaptive motion search is on and
//                     the motion mode is SIMPLE_TRANSLATION.
// Early exit: with adaptive motion search (and resize_mode != RESIZE_RANDOM),
// if some other reference's pred_mv_sad is more than 8x lower than this one's,
// x->pred_mv[ref] is zeroed, x->best_mv is set to INVALID_MV and the function
// returns without writing *rate_mv.
// Only SIMPLE_TRANSLATION and OBMC_CAUSAL motion modes are supported.
static void single_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
                                 BLOCK_SIZE bsize, int mi_row, int mi_col,
                                 int ref_idx, int *rate_mv) {
  MACROBLOCKD *xd = &x->e_mbd;
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MB_MODE_INFO *mbmi = xd->mi[0];
  struct buf_2d backup_yv12[MAX_MB_PLANE] = { { 0, 0, 0, 0, 0 } };
  int bestsme = INT_MAX;
  int step_param;
  int sadpb = x->sadperbit16;
  MV mvp_full;
  int ref = mbmi->ref_frame[ref_idx];
  MV ref_mv = av1_get_ref_mv(x, ref_idx).as_mv;

  MvLimits tmp_mv_limits = x->mv_limits;
  int cost_list[5];

  const YV12_BUFFER_CONFIG *scaled_ref_frame =
      av1_get_scaled_ref_frame(cpi, ref);

  if (scaled_ref_frame) {
    // Swap out the reference frame for a version that's been scaled to
    // match the resolution of the current frame, allowing the existing
    // full-pixel motion search code to be used without additional
    // modifications.
    for (int i = 0; i < num_planes; i++) {
      backup_yv12[i] = xd->plane[i].pre[ref_idx];
    }
    av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
                         num_planes);
  }

  // Work out the size of the first step in the mv step search.
  // 0 here is maximum length first step. 1 is AOMMAX >> 1 etc.
  if (cpi->sf.mv.auto_mv_step_size && cm->show_frame) {
    // Take the weighted average of the step_params based on the last frame's
    // max mv magnitude and that based on the best ref mvs of the current
    // block for the given reference.
    step_param =
        (av1_init_search_range(x->max_mv_context[ref]) + cpi->mv_step_param) /
        2;
  } else {
    step_param = cpi->mv_step_param;
  }

  if (cpi->sf.adaptive_motion_search && bsize < cm->seq_params.sb_size) {
    // Smaller blocks start with a shorter first step.
    int boffset =
        2 * (mi_size_wide_log2[cm->seq_params.sb_size] -
             AOMMIN(mi_size_high_log2[bsize], mi_size_wide_log2[bsize]));
    step_param = AOMMAX(step_param, boffset);
  }

  if (cpi->sf.adaptive_motion_search) {
    int bwl = mi_size_wide_log2[bsize];
    int bhl = mi_size_high_log2[bsize];
    int tlevel = x->pred_mv_sad[ref] >> (bwl + bhl + 4);

    if (tlevel < 5) {
      step_param += 2;
      step_param = AOMMIN(step_param, MAX_MVSEARCH_STEPS - 1);
    }

    // prev_mv_sad is not setup for dynamically scaled frames.
    if (cpi->oxcf.resize_mode != RESIZE_RANDOM) {
      int i;
      for (i = LAST_FRAME; i <= ALTREF_FRAME && cm->show_frame; ++i) {
        // Skip the search when another reference predicts far better.
        if ((x->pred_mv_sad[ref] >> 3) > x->pred_mv_sad[i]) {
          x->pred_mv[ref].row = 0;
          x->pred_mv[ref].col = 0;
          x->best_mv.as_int = INVALID_MV;

          if (scaled_ref_frame) {
            // Swap back the original buffers before returning.
            for (int j = 0; j < num_planes; ++j)
              xd->plane[j].pre[ref_idx] = backup_yv12[j];
          }
          return;
        }
      }
    }
  }

  // Note: MV limits are modified here. Always restore the original values
  // after full-pixel motion search.
  av1_set_mv_search_range(&x->mv_limits, &ref_mv);

  // Full-pixel search center: the block's current MV for non-translational
  // modes, otherwise the reference MV (both converted via >> 3).
  if (mbmi->motion_mode != SIMPLE_TRANSLATION)
    mvp_full = mbmi->mv[0].as_mv;
  else
    mvp_full = ref_mv;

  mvp_full.col >>= 3;
  mvp_full.row >>= 3;

  x->best_mv.as_int = x->second_best_mv.as_int = INVALID_MV;

  switch (mbmi->motion_mode) {
    case SIMPLE_TRANSLATION:
      bestsme = av1_full_pixel_search(
          cpi, x, bsize, &mvp_full, step_param, cpi->sf.mv.search_method, 0,
          sadpb, cond_cost_list(cpi, cost_list), &ref_mv, INT_MAX, 1,
          (MI_SIZE * mi_col), (MI_SIZE * mi_row), 0, &cpi->ss_cfg[SS_CFG_SRC]);
      break;
    case OBMC_CAUSAL:
      bestsme = av1_obmc_full_pixel_search(
          cpi, x, &mvp_full, step_param, sadpb,
          MAX_MVSEARCH_STEPS - 1 - step_param, 1, &cpi->fn_ptr[bsize], &ref_mv,
          &(x->best_mv.as_mv), 0, &cpi->ss_cfg[SS_CFG_SRC]);
      break;
    default: assert(0 && "Invalid motion mode!\n");
  }

  if (scaled_ref_frame) {
    // Swap back the original buffers for subpel motion search.
    for (int i = 0; i < num_planes; i++) {
      xd->plane[i].pre[ref_idx] = backup_yv12[i];
    }
  }

  x->mv_limits = tmp_mv_limits;

  if (cpi->common.cur_frame_force_integer_mv) {
    x->best_mv.as_mv.row *= 8;
    x->best_mv.as_mv.col *= 8;
  }
  const int use_fractional_mv =
      bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0;
  if (use_fractional_mv) {
    int dis; /* TODO: use dis in distortion calculation later. */
    switch (mbmi->motion_mode) {
      case SIMPLE_TRANSLATION:
        if (cpi->sf.use_accurate_subpel_search) {
          int best_mv_var;
          const int try_second = x->second_best_mv.as_int != INVALID_MV &&
                                 x->second_best_mv.as_int != x->best_mv.as_int;
          const int pw = block_size_wide[bsize];
          const int ph = block_size_high[bsize];
          best_mv_var = cpi->find_fractional_mv_step(
              x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
              x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
              cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
              x->nmv_vec_cost, x->mv_cost_stack, &dis, &x->pred_sse[ref], NULL,
              NULL, 0, 0, pw, ph, cpi->sf.use_accurate_subpel_search, 1);

          if (try_second) {
            // Also refine the runner-up from the full-pixel search and keep
            // whichever of the two refined MVs has the lower variance.
            const int minc =
                AOMMAX(x->mv_limits.col_min * 8, ref_mv.col - MV_MAX);
            const int maxc =
                AOMMIN(x->mv_limits.col_max * 8, ref_mv.col + MV_MAX);
            const int minr =
                AOMMAX(x->mv_limits.row_min * 8, ref_mv.row - MV_MAX);
            const int maxr =
                AOMMIN(x->mv_limits.row_max * 8, ref_mv.row + MV_MAX);
            int this_var;
            MV best_mv = x->best_mv.as_mv;

            x->best_mv = x->second_best_mv;
            if (x->best_mv.as_mv.row * 8 <= maxr &&
                x->best_mv.as_mv.row * 8 >= minr &&
                x->best_mv.as_mv.col * 8 <= maxc &&
                x->best_mv.as_mv.col * 8 >= minc) {
              this_var = cpi->find_fractional_mv_step(
                  x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
                  x->errorperbit, &cpi->fn_ptr[bsize],
                  cpi->sf.mv.subpel_force_stop,
                  cpi->sf.mv.subpel_iters_per_step,
                  cond_cost_list(cpi, cost_list), x->nmv_vec_cost,
                  x->mv_cost_stack, &dis, &x->pred_sse[ref], NULL, NULL, 0, 0,
                  pw, ph, cpi->sf.use_accurate_subpel_search, 0);
              if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
            }
            // Restore the winner unconditionally: when the second candidate
            // is out of range, x->best_mv would otherwise be left holding the
            // un-refined full-pel second_best_mv.
            x->best_mv.as_mv = best_mv;
          }
        } else {
          cpi->find_fractional_mv_step(
              x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
              x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
              cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
              x->nmv_vec_cost, x->mv_cost_stack, &dis, &x->pred_sse[ref], NULL,
              NULL, 0, 0, 0, 0, 0, 1);
        }
        break;
      case OBMC_CAUSAL:
        av1_find_best_obmc_sub_pixel_tree_up(
            x, cm, mi_row, mi_col, &x->best_mv.as_mv, &ref_mv,
            cm->allow_high_precision_mv, x->errorperbit, &cpi->fn_ptr[bsize],
            cpi->sf.mv.subpel_force_stop, cpi->sf.mv.subpel_iters_per_step,
            x->nmv_vec_cost, x->mv_cost_stack, &dis, &x->pred_sse[ref], 0,
            cpi->sf.use_accurate_subpel_search);
        break;
      default: assert(0 && "Invalid motion mode!\n");
    }
  }
  *rate_mv = av1_mv_bit_cost(&x->best_mv.as_mv, &ref_mv, x->nmv_vec_cost,
                             x->mv_cost_stack, MV_COST_WEIGHT);

  if (cpi->sf.adaptive_motion_search && mbmi->motion_mode == SIMPLE_TRANSLATION)
    x->pred_mv[ref] = x->best_mv.as_mv;
}
7253 
restore_dst_buf(MACROBLOCKD * xd,const BUFFER_SET dst,const int num_planes)7254 static INLINE void restore_dst_buf(MACROBLOCKD *xd, const BUFFER_SET dst,
7255                                    const int num_planes) {
7256   for (int i = 0; i < num_planes; i++) {
7257     xd->plane[i].dst.buf = dst.plane[i];
7258     xd->plane[i].dst.stride = dst.stride[i];
7259   }
7260 }
7261 
// Build the luma prediction of the block from the *other* reference of a
// compound pair (index !ref_idx) using other_mv, writing a pw-wide, ph-high
// block (pw = block width, ph = block height) into second_pred. Afterwards the
// distance-weighted compound parameters in xd->jcp_param are (re)assigned.
// Must only be called for compound (two-reference) blocks.
static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x,
                                    BLOCK_SIZE bsize, const MV *other_mv,
                                    int mi_row, int mi_col, const int block,
                                    int ref_idx, uint8_t *second_pred) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int other_idx = !ref_idx;
  const int plane = 0;
  const int width = block_size_wide[bsize];
  const int height = block_size_high[bsize];

  // This function should only ever be called for compound modes.
  assert(has_second_ref(mbmi));

  // Pixel position of the sub-8x8 at index "block": bit 0 selects a 4-pixel
  // column offset, the remaining bits a 4-pixel row offset.
  const struct macroblockd_plane *const luma = &xd->plane[plane];
  const int sub_col = block & 1;
  const int sub_row = (block - sub_col) >> 1;
  const int pix_col =
      ((mi_col * MI_SIZE) >> luma->subsampling_x) + 4 * sub_col;
  const int pix_row =
      ((mi_row * MI_SIZE) >> luma->subsampling_y) + 4 * sub_row;

  // Warp permissions: global warp follows the other reference's global
  // motion model, local warp follows the block's motion mode.
  const int other_ref = mbmi->ref_frame[other_idx];
  const WarpedMotionParams *const gm = &xd->global_motion[other_ref];
  WarpTypesAllowed warp_types;
  warp_types.global_warp_allowed = is_global_mv_block(mbmi, gm->wmtype);
  warp_types.local_warp_allowed = mbmi->motion_mode == WARPED_CAUSAL;

  struct buf_2d other_buf = luma->pre[other_idx];

  struct scale_factors sf;
  av1_setup_scale_factors_for_frame(&sf, other_buf.width, other_buf.height,
                                    cm->width, cm->height);

  ConvolveParams conv_params = get_conv_params(0, plane, xd->bd);

  av1_build_inter_predictor(other_buf.buf, other_buf.stride, second_pred,
                            width, other_mv, &sf, width, height, &conv_params,
                            mbmi->interp_filters, &warp_types, pix_col,
                            pix_row, plane, other_idx, MV_PRECISION_Q3,
                            mi_col * MI_SIZE, mi_row * MI_SIZE, xd,
                            cm->allow_warped_motion);

  av1_dist_wtd_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
                                  &xd->jcp_param.bck_offset,
                                  &xd->jcp_param.use_dist_wtd_comp_avg, 1);
}
7307 
// Search for the best mv for one component of a compound,
// given that the other component is fixed.
//
// Parameters:
//   bsize       - block size; selects cpi->fn_ptr[bsize] and the prediction
//                 width/height.
//   this_mv     - in: starting MV for the searched reference (its >> 3 value
//                 seeds the full-pixel search); out: replaced by the search
//                 result when the final score is below INT_MAX.
//   second_pred - prediction from the other (fixed) reference; passed to both
//                 the full-pixel and sub-pixel searches.
//   mask        - optional compound mask with row stride mask_stride; when
//                 NULL, av1_get_mvpred_av_var() scores the full-pixel result,
//                 otherwise av1_get_mvpred_mask_var().
//   rate_mv     - out: bit cost of the final *this_mv relative to the
//                 reference MV of ref_idx.
//   ref_idx     - which reference of the pair (0 or 1) is being searched.
static void compound_single_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
                                          BLOCK_SIZE bsize, MV *this_mv,
                                          int mi_row, int mi_col,
                                          const uint8_t *second_pred,
                                          const uint8_t *mask, int mask_stride,
                                          int *rate_mv, int ref_idx) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const int pw = block_size_wide[bsize];
  const int ph = block_size_high[bsize];
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int ref = mbmi->ref_frame[ref_idx];
  const int_mv ref_mv = av1_get_ref_mv(x, ref_idx);
  struct macroblockd_plane *const pd = &xd->plane[0];

  struct buf_2d backup_yv12[MAX_MB_PLANE];
  const YV12_BUFFER_CONFIG *const scaled_ref_frame =
      av1_get_scaled_ref_frame(cpi, ref);

  // Check that this is either an interinter or an interintra block
  assert(has_second_ref(mbmi) || (ref_idx == 0 && is_interintra_mode(mbmi)));

  // Store the first prediction buffer.
  // When searching the second reference, its luma buffer is temporarily
  // copied into pre[0]; the original pre[0] is restored near the end.
  // NOTE(review): presumably the search helpers read pre[0] — confirm.
  struct buf_2d orig_yv12;
  if (ref_idx) {
    orig_yv12 = pd->pre[0];
    pd->pre[0] = pd->pre[ref_idx];
  }

  if (scaled_ref_frame) {
    int i;
    // Swap out the reference frame for a version that's been scaled to
    // match the resolution of the current frame, allowing the existing
    // full-pixel motion search code to be used without additional
    // modifications.
    for (i = 0; i < num_planes; i++) backup_yv12[i] = xd->plane[i].pre[ref_idx];
    av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
                         num_planes);
  }

  int bestsme = INT_MAX;
  int sadpb = x->sadperbit16;
  MV *const best_mv = &x->best_mv.as_mv;
  int search_range = SEARCH_RANGE_8P;

  // Saved so the search-range clamp below can be undone.
  MvLimits tmp_mv_limits = x->mv_limits;

  // Do compound motion search on the current reference frame.
  av1_set_mv_search_range(&x->mv_limits, &ref_mv.as_mv);

  // Use the mv result from the single mode as mv predictor.
  *best_mv = *this_mv;

  best_mv->col >>= 3;
  best_mv->row >>= 3;

  // Small-range full-pixel motion search.
  bestsme = av1_refining_search_8p_c(x, sadpb, search_range,
                                     &cpi->fn_ptr[bsize], mask, mask_stride,
                                     ref_idx, &ref_mv.as_mv, second_pred);
  // Re-score the full-pixel result with the compound (masked or averaged)
  // variance.
  if (bestsme < INT_MAX) {
    if (mask)
      bestsme =
          av1_get_mvpred_mask_var(x, best_mv, &ref_mv.as_mv, second_pred, mask,
                                  mask_stride, ref_idx, &cpi->fn_ptr[bsize], 1);
    else
      bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv.as_mv, second_pred,
                                      &cpi->fn_ptr[bsize], 1);
  }

  x->mv_limits = tmp_mv_limits;

  if (scaled_ref_frame) {
    // Swap back the original buffers for subpel motion search.
    for (int i = 0; i < num_planes; i++) {
      xd->plane[i].pre[ref_idx] = backup_yv12[i];
    }
  }

  // With integer-only MVs there is no sub-pixel step; scale the full-pixel
  // result by 8 directly.
  if (cpi->common.cur_frame_force_integer_mv) {
    x->best_mv.as_mv.row *= 8;
    x->best_mv.as_mv.col *= 8;
  }
  const int use_fractional_mv =
      bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0;
  if (use_fractional_mv) {
    int dis; /* TODO: use dis in distortion calculation later. */
    unsigned int sse;
    // Note: 0 is passed where single_motion_search() passes
    // cpi->sf.mv.subpel_force_stop.
    bestsme = cpi->find_fractional_mv_step(
        x, cm, mi_row, mi_col, &ref_mv.as_mv,
        cpi->common.allow_high_precision_mv, x->errorperbit,
        &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
        x->nmv_vec_cost, x->mv_cost_stack, &dis, &sse, second_pred, mask,
        mask_stride, ref_idx, pw, ph, cpi->sf.use_accurate_subpel_search, 1);
  }

  // Restore the pointer to the first unscaled prediction buffer.
  if (ref_idx) pd->pre[0] = orig_yv12;

  if (bestsme < INT_MAX) *this_mv = *best_mv;

  *rate_mv = 0;

  *rate_mv += av1_mv_bit_cost(this_mv, &ref_mv.as_mv, x->nmv_vec_cost,
                              x->mv_cost_stack, MV_COST_WEIGHT);
}
7417 
7418 // Wrapper for compound_single_motion_search, for the common case
7419 // where the second prediction is also an inter mode.
compound_single_motion_search_interinter(const AV1_COMP * cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int_mv * cur_mv,int mi_row,int mi_col,const uint8_t * mask,int mask_stride,int * rate_mv,const int block,int ref_idx)7420 static void compound_single_motion_search_interinter(
7421     const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int_mv *cur_mv,
7422     int mi_row, int mi_col, const uint8_t *mask, int mask_stride, int *rate_mv,
7423     const int block, int ref_idx) {
7424   MACROBLOCKD *xd = &x->e_mbd;
7425   // This function should only ever be called for compound modes
7426   assert(has_second_ref(xd->mi[0]));
7427 
7428   // Prediction buffer from second frame.
7429   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
7430   uint8_t *second_pred;
7431   if (is_cur_buf_hbd(xd))
7432     second_pred = CONVERT_TO_BYTEPTR(second_pred_alloc_16);
7433   else
7434     second_pred = (uint8_t *)second_pred_alloc_16;
7435 
7436   MV *this_mv = &cur_mv[ref_idx].as_mv;
7437   const MV *other_mv = &cur_mv[!ref_idx].as_mv;
7438 
7439   build_second_inter_pred(cpi, x, bsize, other_mv, mi_row, mi_col, block,
7440                           ref_idx, second_pred);
7441 
7442   compound_single_motion_search(cpi, x, bsize, this_mv, mi_row, mi_col,
7443                                 second_pred, mask, mask_stride, rate_mv,
7444                                 ref_idx);
7445 }
7446 
// Masked compound motion search. tmp_mv starts as a copy of cur_mv and is
// refined according to which:
//   0 or 1 - only MV number `which` is searched, the other stays fixed;
//   2      - both MVs are searched jointly;
//   other  - nothing is searched (tmp_mv is just the copy).
// The mask comes from comp_data for the block's sb_type; *rate_mv receives the
// MV cost from the search that ran.
static void do_masked_motion_search_indexed(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
    int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int which) {
  MACROBLOCKD *xd = &x->e_mbd;
  const uint8_t *const mask =
      av1_get_compound_type_mask(comp_data, xd->mi[0]->sb_type);
  const int mask_stride = block_size_wide[bsize];

  for (int i = 0; i < 2; ++i) tmp_mv[i].as_int = cur_mv[i].as_int;

  switch (which) {
    case 0:
    case 1:
      compound_single_motion_search_interinter(cpi, x, bsize, tmp_mv, mi_row,
                                               mi_col, mask, mask_stride,
                                               rate_mv, 0, which);
      break;
    case 2:
      joint_motion_search(cpi, x, bsize, tmp_mv, mi_row, mi_col, NULL, mask,
                          mask_stride, rate_mv, 0);
      break;
    default: break;
  }
}
7471 
// USE_DISCOUNT_NEWMV_TEST is 0, so everything up to the matching #endif is
// compiled out.
#define USE_DISCOUNT_NEWMV_TEST 0
#if USE_DISCOUNT_NEWMV_TEST
// In some situations we want to discount the apparent cost of a new motion
// vector. Where there is a subtle motion field and especially where there is
// low spatial complexity then it can be hard to cover the cost of a new motion
// vector in a single block, even if that motion vector reduces distortion.
// However, once established that vector may be usable through the nearest and
// near mv modes to reduce distortion in subsequent blocks and also improve
// visual quality.
#define NEW_MV_DISCOUNT_FACTOR 8
static INLINE void get_this_mv(int_mv *this_mv, PREDICTION_MODE this_mode,
                               int ref_idx, int ref_mv_idx,
                               const MV_REFERENCE_FRAME *ref_frame,
                               const MB_MODE_INFO_EXT *mbmi_ext);
// Returns 1 when the cost of this_mv should be discounted, i.e. when all of
// the following hold:
//   - this_mode is NEWMV, this_mv is non-zero, and the source frame is not an
//     alt-ref frame;
//   - the NEARESTMV candidate and every NEARMV candidate for the block's first
//     reference are zero;
//   - if that reference's global motion type is TRANSLATION or simpler,
//     this_mv differs from the GLOBALMV candidate.
// Returns 0 otherwise.
static int discount_newmv_test(const AV1_COMP *const cpi, const MACROBLOCK *x,
                               PREDICTION_MODE this_mode, int_mv this_mv) {
  if (this_mode == NEWMV && this_mv.as_int != 0 &&
      !cpi->rc.is_src_frame_alt_ref) {
    // Only discount new_mv when nearest_mv and all near_mv are zero, and the
    // new_mv is not equal to global_mv
    const AV1_COMMON *const cm = &cpi->common;
    const MACROBLOCKD *const xd = &x->e_mbd;
    const MB_MODE_INFO *const mbmi = xd->mi[0];
    // Consider only the first reference, as a single-reference pair.
    const MV_REFERENCE_FRAME tmp_ref_frames[2] = { mbmi->ref_frame[0],
                                                   NONE_FRAME };
    const uint8_t ref_frame_type = av1_ref_frame_type(tmp_ref_frames);
    int_mv nearest_mv;
    get_this_mv(&nearest_mv, NEARESTMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
    int ret = nearest_mv.as_int == 0;
    for (int ref_mv_idx = 0;
         ref_mv_idx < x->mbmi_ext->ref_mv_count[ref_frame_type]; ++ref_mv_idx) {
      int_mv near_mv;
      get_this_mv(&near_mv, NEARMV, 0, ref_mv_idx, tmp_ref_frames, x->mbmi_ext);
      ret &= near_mv.as_int == 0;
    }
    if (cm->global_motion[tmp_ref_frames[0]].wmtype <= TRANSLATION) {
      int_mv global_mv;
      get_this_mv(&global_mv, GLOBALMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
      ret &= global_mv.as_int != this_mv.as_int;
    }
    return ret;
  }
  return 0;
}
#endif  // USE_DISCOUNT_NEWMV_TEST
7517 
7518 #define LEFT_TOP_MARGIN ((AOM_BORDER_IN_PIXELS - AOM_INTERP_EXTEND) << 3)
7519 #define RIGHT_BOTTOM_MARGIN ((AOM_BORDER_IN_PIXELS - AOM_INTERP_EXTEND) << 3)
7520 
7521 // TODO(jingning): this mv clamping function should be block size dependent.
clamp_mv2(MV * mv,const MACROBLOCKD * xd)7522 static INLINE void clamp_mv2(MV *mv, const MACROBLOCKD *xd) {
7523   clamp_mv(mv, xd->mb_to_left_edge - LEFT_TOP_MARGIN,
7524            xd->mb_to_right_edge + RIGHT_BOTTOM_MARGIN,
7525            xd->mb_to_top_edge - LEFT_TOP_MARGIN,
7526            xd->mb_to_bottom_edge + RIGHT_BOTTOM_MARGIN);
7527 }
7528 
estimate_wedge_sign(const AV1_COMP * cpi,const MACROBLOCK * x,const BLOCK_SIZE bsize,const uint8_t * pred0,int stride0,const uint8_t * pred1,int stride1)7529 static int estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
7530                                const BLOCK_SIZE bsize, const uint8_t *pred0,
7531                                int stride0, const uint8_t *pred1, int stride1) {
7532   static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
7533     //                            4X4
7534     BLOCK_INVALID,
7535     // 4X8,        8X4,           8X8
7536     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
7537     // 8X16,       16X8,          16X16
7538     BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
7539     // 16X32,      32X16,         32X32
7540     BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
7541     // 32X64,      64X32,         64X64
7542     BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
7543     // 64x128,     128x64,        128x128
7544     BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
7545     // 4X16,       16X4,          8X32
7546     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
7547     // 32X8,       16X64,         64X16
7548     BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
7549   };
7550   const struct macroblock_plane *const p = &x->plane[0];
7551   const uint8_t *src = p->src.buf;
7552   int src_stride = p->src.stride;
7553   const int bw = block_size_wide[bsize];
7554   const int bh = block_size_high[bsize];
7555   uint32_t esq[2][4];
7556   int64_t tl, br;
7557 
7558   const BLOCK_SIZE f_index = split_qtr[bsize];
7559   assert(f_index != BLOCK_INVALID);
7560 
7561   if (is_cur_buf_hbd(&x->e_mbd)) {
7562     pred0 = CONVERT_TO_BYTEPTR(pred0);
7563     pred1 = CONVERT_TO_BYTEPTR(pred1);
7564   }
7565 
7566   cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
7567   cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, pred0 + bw / 2, stride0,
7568                           &esq[0][1]);
7569   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride, src_stride,
7570                           pred0 + bh / 2 * stride0, stride0, &esq[0][2]);
7571   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride + bw / 2, src_stride,
7572                           pred0 + bh / 2 * stride0 + bw / 2, stride0,
7573                           &esq[0][3]);
7574   cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
7575   cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, pred1 + bw / 2, stride1,
7576                           &esq[1][1]);
7577   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride, src_stride,
7578                           pred1 + bh / 2 * stride1, stride0, &esq[1][2]);
7579   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride + bw / 2, src_stride,
7580                           pred1 + bh / 2 * stride1 + bw / 2, stride0,
7581                           &esq[1][3]);
7582 
7583   tl = ((int64_t)esq[0][0] + esq[0][1] + esq[0][2]) -
7584        ((int64_t)esq[1][0] + esq[1][1] + esq[1][2]);
7585   br = ((int64_t)esq[1][3] + esq[1][1] + esq[1][2]) -
7586        ((int64_t)esq[0][3] + esq[0][1] + esq[0][2]);
7587   return (tl + br > 0);
7588 }
7589 
7590 // Choose the best wedge index and sign
pick_wedge(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const uint8_t * const p0,const int16_t * const residual1,const int16_t * const diff10,int * const best_wedge_sign,int * const best_wedge_index)7591 static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
7592                           const BLOCK_SIZE bsize, const uint8_t *const p0,
7593                           const int16_t *const residual1,
7594                           const int16_t *const diff10,
7595                           int *const best_wedge_sign,
7596                           int *const best_wedge_index) {
7597   const MACROBLOCKD *const xd = &x->e_mbd;
7598   const struct buf_2d *const src = &x->plane[0].src;
7599   const int bw = block_size_wide[bsize];
7600   const int bh = block_size_high[bsize];
7601   const int N = bw * bh;
7602   assert(N >= 64);
7603   int rate;
7604   int64_t dist;
7605   int64_t rd, best_rd = INT64_MAX;
7606   int wedge_index;
7607   int wedge_sign;
7608   int wedge_types = (1 << get_wedge_bits_lookup(bsize));
7609   const uint8_t *mask;
7610   uint64_t sse;
7611   const int hbd = is_cur_buf_hbd(xd);
7612   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
7613 
7614   DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]);  // src - pred0
7615   if (hbd) {
7616     aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
7617                               CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
7618   } else {
7619     aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
7620   }
7621 
7622   int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
7623                         (int64_t)aom_sum_squares_i16(residual1, N)) *
7624                        (1 << WEDGE_WEIGHT_BITS) / 2;
7625   int16_t *ds = residual0;
7626 
7627   av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
7628 
7629   for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
7630     mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
7631 
7632     wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
7633 
7634     mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
7635     sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
7636     sse = ROUND_POWER_OF_TWO(sse, bd_round);
7637 
7638     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
7639                                                   &rate, &dist);
7640     // int rate2;
7641     // int64_t dist2;
7642     // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
7643     // printf("sse %"PRId64": leagacy: %d %"PRId64", curvfit %d %"PRId64"\n",
7644     // sse, rate, dist, rate2, dist2); dist = dist2;
7645     // rate = rate2;
7646 
7647     rate += x->wedge_idx_cost[bsize][wedge_index];
7648     rd = RDCOST(x->rdmult, rate, dist);
7649 
7650     if (rd < best_rd) {
7651       *best_wedge_index = wedge_index;
7652       *best_wedge_sign = wedge_sign;
7653       best_rd = rd;
7654     }
7655   }
7656 
7657   return best_rd -
7658          RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
7659 }
7660 
7661 // Choose the best wedge index the specified sign
pick_wedge_fixed_sign(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const int16_t * const residual1,const int16_t * const diff10,const int wedge_sign,int * const best_wedge_index)7662 static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi,
7663                                      const MACROBLOCK *const x,
7664                                      const BLOCK_SIZE bsize,
7665                                      const int16_t *const residual1,
7666                                      const int16_t *const diff10,
7667                                      const int wedge_sign,
7668                                      int *const best_wedge_index) {
7669   const MACROBLOCKD *const xd = &x->e_mbd;
7670 
7671   const int bw = block_size_wide[bsize];
7672   const int bh = block_size_high[bsize];
7673   const int N = bw * bh;
7674   assert(N >= 64);
7675   int rate;
7676   int64_t dist;
7677   int64_t rd, best_rd = INT64_MAX;
7678   int wedge_index;
7679   int wedge_types = (1 << get_wedge_bits_lookup(bsize));
7680   const uint8_t *mask;
7681   uint64_t sse;
7682   const int hbd = is_cur_buf_hbd(xd);
7683   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
7684   for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
7685     mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
7686     sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
7687     sse = ROUND_POWER_OF_TWO(sse, bd_round);
7688 
7689     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
7690                                                   &rate, &dist);
7691     rate += x->wedge_idx_cost[bsize][wedge_index];
7692     rd = RDCOST(x->rdmult, rate, dist);
7693 
7694     if (rd < best_rd) {
7695       *best_wedge_index = wedge_index;
7696       best_rd = rd;
7697     }
7698   }
7699   return best_rd -
7700          RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
7701 }
7702 
// Selects the wedge index and sign for an inter-inter wedge compound.
// The result is stored in the current block's interinter_comp.
//
// With sf.fast_wedge_sign_estimate, the sign is estimated once from quadrant
// SSEs and only the index is searched. Otherwise both index and sign are
// searched. Returns the RD cost reported by that search, which excludes the
// wedge index cost.
static int64_t pick_interinter_wedge(
    const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10) {
  assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
  assert(cpi->common.seq_params.enable_masked_compound);

  int chosen_index = -1;
  int chosen_sign = 0;
  int64_t search_rd;
  if (!cpi->sf.fast_wedge_sign_estimate) {
    search_rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &chosen_sign,
                           &chosen_index);
  } else {
    // Both predictors are laid out with the block width as stride.
    const int pred_stride = block_size_wide[bsize];
    chosen_sign =
        estimate_wedge_sign(cpi, x, bsize, p0, pred_stride, p1, pred_stride);
    search_rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10,
                                      chosen_sign, &chosen_index);
  }

  INTERINTER_COMPOUND_DATA *const comp = &x->e_mbd.mi[0]->interinter_comp;
  comp->wedge_sign = chosen_sign;
  comp->wedge_index = chosen_index;
  return search_rd;
}
7731 
// Chooses between the DIFFWTD mask types by modelled RD cost. Each mask is
// built from the two predictors p0/p1.
//
// The chosen type is written to the block's interinter_comp.mask_type, and
// its mask is left in xd->seg_mask. Returns the best modelled RD cost; no
// mask-type signalling cost is included.
static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, const BLOCK_SIZE bsize,
                                   const uint8_t *const p0,
                                   const uint8_t *const p1,
                                   const int16_t *const residual1,
                                   const int16_t *const diff10) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const int w = block_size_wide[bsize];
  const int h = block_size_high[bsize];
  const int num_pels = 1 << num_pels_log2_lookup[bsize];
  const int use_hbd = is_cur_buf_hbd(xd);
  const int sse_shift = use_hbd ? (xd->bd - 8) * 2 : 0;

  // Mask type 0 is built straight into xd->seg_mask. Mask type 1 goes into a
  // local scratch buffer so it cannot overwrite type 0's mask.
  DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
  uint8_t *masks[2] = { xd->seg_mask, seg_mask };

  DIFFWTD_MASK_TYPE winner = 0;
  int64_t winner_rd = INT64_MAX;
  for (DIFFWTD_MASK_TYPE type = 0; type < DIFFWTD_MASK_TYPES; type++) {
    uint8_t *const mask = masks[type];
    if (!use_hbd) {
      av1_build_compound_diffwtd_mask(mask, type, p0, w, p1, w, h, w);
    } else {
      av1_build_compound_diffwtd_mask_highbd(mask, type,
                                             CONVERT_TO_BYTEPTR(p0), w,
                                             CONVERT_TO_BYTEPTR(p1), w, h, w,
                                             xd->bd);
    }

    uint64_t masked_sse =
        av1_wedge_sse_from_residuals(residual1, diff10, mask, num_pels);
    masked_sse = ROUND_POWER_OF_TWO(masked_sse, sse_shift);

    int rate;
    int64_t dist;
    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, masked_sse,
                                                  num_pels, &rate, &dist);
    const int64_t this_rd = RDCOST(x->rdmult, rate, dist);
    if (this_rd < winner_rd) {
      winner_rd = this_rd;
      winner = type;
    }
  }

  xd->mi[0]->interinter_comp.mask_type = winner;
  // The inverse mask was built in the scratch buffer, so move it into
  // xd->seg_mask. This assumes DIFFWTD_38_INV is mask type 1.
  if (winner == DIFFWTD_38_INV) {
    memcpy(xd->seg_mask, seg_mask, num_pels * 2);
  }
  return winner_rd;
}
7783 
pick_interintra_wedge(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const uint8_t * const p0,const uint8_t * const p1)7784 static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
7785                                      const MACROBLOCK *const x,
7786                                      const BLOCK_SIZE bsize,
7787                                      const uint8_t *const p0,
7788                                      const uint8_t *const p1) {
7789   const MACROBLOCKD *const xd = &x->e_mbd;
7790   MB_MODE_INFO *const mbmi = xd->mi[0];
7791   assert(is_interintra_wedge_used(bsize));
7792   assert(cpi->common.seq_params.enable_interintra_compound);
7793 
7794   const struct buf_2d *const src = &x->plane[0].src;
7795   const int bw = block_size_wide[bsize];
7796   const int bh = block_size_high[bsize];
7797   DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]);  // src - pred1
7798   DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]);     // pred1 - pred0
7799   if (is_cur_buf_hbd(xd)) {
7800     aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
7801                               CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
7802     aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
7803                               CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
7804   } else {
7805     aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
7806     aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
7807   }
7808   int wedge_index = -1;
7809   int64_t rd =
7810       pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index);
7811 
7812   mbmi->interintra_wedge_sign = 0;
7813   mbmi->interintra_wedge_index = wedge_index;
7814   return rd;
7815 }
7816 
// Runs the mask search matching the block's current inter-inter compound
// type: wedge search for COMPOUND_WEDGE, DIFFWTD mask search for
// COMPOUND_DIFFWTD. Returns the modelled RD cost from that search. Any other
// type is unexpected (asserts in debug builds) and yields 0.
static int64_t pick_interinter_mask(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    const BLOCK_SIZE bsize,
                                    const uint8_t *const p0,
                                    const uint8_t *const p1,
                                    const int16_t *const residual1,
                                    const int16_t *const diff10) {
  const COMPOUND_TYPE type = x->e_mbd.mi[0]->interinter_comp.type;
  if (type == COMPOUND_WEDGE) {
    return pick_interinter_wedge(cpi, x, bsize, p0, p1, residual1, diff10);
  }
  if (type == COMPOUND_DIFFWTD) {
    return pick_interinter_seg(cpi, x, bsize, p0, p1, residual1, diff10);
  }
  assert(0);
  return 0;
}
7832 
// Runs a masked motion search for the NEWMV side(s) of a masked compound
// mode and writes the resulting MV(s) into mbmi->mv.
//
// The block's interinter_comp.seg_mask is first pointed at xd->seg_mask, and
// the compound data is passed to the search. Returns the MV rate produced by
// the search, or 0 when this_mode has no NEWMV component.
static int interinter_compound_motion_search(const AV1_COMP *const cpi,
                                             MACROBLOCK *x,
                                             const int_mv *const cur_mv,
                                             const BLOCK_SIZE bsize,
                                             const PREDICTION_MODE this_mode,
                                             int mi_row, int mi_col) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->interinter_comp.seg_mask = xd->seg_mask;

  // Search selector handed to do_masked_motion_search_indexed():
  // 0 -> first MV, 1 -> second MV, 2 -> both. This matches which entries of
  // the result are copied back below.
  int which_mv;
  switch (this_mode) {
    case NEW_NEWMV: which_mv = 2; break;
    case NEW_NEARESTMV:
    case NEW_NEARMV: which_mv = 0; break;
    case NEAREST_NEWMV:
    case NEAR_NEWMV: which_mv = 1; break;
    default: return 0;
  }

  int_mv searched_mv[2];
  int searched_rate = 0;
  do_masked_motion_search_indexed(cpi, x, cur_mv, &mbmi->interinter_comp,
                                  bsize, mi_row, mi_col, searched_mv,
                                  &searched_rate, which_mv);
  if (which_mv != 1) mbmi->mv[0].as_int = searched_mv[0].as_int;
  if (which_mv != 0) mbmi->mv[1].as_int = searched_mv[1].as_int;
  return searched_rate;
}
7862 
// Builds the reference-0 and reference-1 predictors for plane 0 only into
// preds0 / preds1 (strides written to |strides|). It then derives, with the
// block width as stride:
//   residual1 = src - pred1
//   diff10    = pred1 - pred0
// Masked compound mode search uses these inputs.
static void get_inter_predictors_masked_compound(
    const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
    int mi_row, int mi_col, uint8_t **preds0, uint8_t **preds1,
    int16_t *residual1, int16_t *diff10, int *strides) {
  MACROBLOCKD *xd = &x->e_mbd;
  const int can_use_previous = cpi->common.allow_warped_motion;

  av1_build_inter_predictors_for_planes_single_buf(
      xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides, can_use_previous);
  av1_build_inter_predictors_for_planes_single_buf(
      xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides, can_use_previous);

  const int w = block_size_wide[bsize];
  const int h = block_size_high[bsize];
  const struct buf_2d *const src = &x->plane[0].src;
  if (!is_cur_buf_hbd(xd)) {
    aom_subtract_block(h, w, residual1, w, src->buf, src->stride, *preds1, w);
    aom_subtract_block(h, w, diff10, w, *preds1, w, *preds0, w);
    return;
  }

  const uint8_t *const hbd_pred0 = CONVERT_TO_BYTEPTR(*preds0);
  const uint8_t *const hbd_pred1 = CONVERT_TO_BYTEPTR(*preds1);
  aom_highbd_subtract_block(h, w, residual1, w, src->buf, src->stride,
                            hbd_pred1, w, xd->bd);
  aom_highbd_subtract_block(h, w, diff10, w, hbd_pred1, w, hbd_pred0, w,
                            xd->bd);
}
7889 
// Builds the masked compound prediction for the compound type currently set
// in mbmi->interinter_comp.type (wedge or DIFFWTD) and estimates its RD cost.
//
// Steps, as implemented below:
//  1. If *calc_pred_masked_compound is set, builds the two single-reference
//     predictors (preds0/preds1) and residual1/diff10, then clears the flag
//     so later calls reuse them.
//  2. With sf.prune_wedge_pred_diff_based, a wedge is skipped when the
//     per-pixel mean squared difference between the two predictors is below
//     8. The threshold is 64 for modes without a NEWMV component.
//  3. Picks the mask via pick_interinter_mask() and adds the mask signalling
//     rate to *rs2. Returns early if the mode rate alone (*rs2 + mode_rate)
//     already exceeds ref_best_rd.
//  4. If no stats are cached for this compound type
//     (comp_rate[compound_type] == INT_MAX):
//     a. For wedge + NEWMV, unless sf.disable_interinter_wedge_newmv_search
//        is set, runs a masked motion search. The searched MVs are kept only
//        if their model RD beats the mask-search RD.
//     b. May prune against comp_best_model_rd.
//     c. Runs estimate_yrd_for_sb() and caches its rate, distortion and the
//        model RD in comp_rate / comp_dist / comp_model_rd.
//     Otherwise the cached rate/distortion are reused.
//
// Outputs on non-pruned paths:
//  - *out_rate_mv is the MV rate of the MVs left in mbmi. This is rate_mv
//    unless searched MVs were kept.
//  - *comp_model_rd_cur is the model RD estimate. It is set to INT64_MAX
//    whenever the type is pruned.
// Returns the RD cost, computed from *rs2 + *out_rate_mv + the
// estimate_yrd_for_sb() rate and its distortion. Returns INT64_MAX when
// pruned or when estimate_yrd_for_sb() returns INT64_MAX.
static int64_t build_and_cost_compound_type(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
    int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
    uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
    int mi_row, int mi_col, int mode_rate, int64_t ref_best_rd,
    int *calc_pred_masked_compound, int32_t *comp_rate, int64_t *comp_dist,
    int64_t *const comp_model_rd, const int64_t comp_best_model_rd,
    int64_t *const comp_model_rd_cur) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int64_t best_rd_cur = INT64_MAX;
  int64_t rd = INT64_MAX;
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  int rate_sum, tmp_skip_txfm_sb;
  int64_t dist_sum, tmp_skip_sse_sb;

  // TODO(any): Save pred and mask calculation as well into records. However
  // this may increase memory requirements as compound segment mask needs to be
  // stored in each record.
  if (*calc_pred_masked_compound) {
    get_inter_predictors_masked_compound(cpi, x, bsize, mi_row, mi_col, preds0,
                                         preds1, residual1, diff10, strides);
    *calc_pred_masked_compound = 0;
  }
  if (cpi->sf.prune_wedge_pred_diff_based && compound_type == COMPOUND_WEDGE) {
    unsigned int sse;
    if (is_cur_buf_hbd(xd))
      (void)cpi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
                                  CONVERT_TO_BYTEPTR(*preds1), *strides, &sse);
    else
      (void)cpi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides, &sse);
    // Per-pixel mean of the squared difference between the two predictors.
    const unsigned int mse =
        ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
    // If the two predictors are very similar, skip the wedge compound search.
    if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
  }

  // best_rd_cur: mask-search RD plus the compound signalling rate and MV rate.
  best_rd_cur =
      pick_interinter_mask(cpi, x, bsize, *preds0, *preds1, residual1, diff10);
  *rs2 += get_interinter_compound_mask_rate(x, mbmi);
  best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);

  // The true rate_mv may change after motion search. Even so, if the mode
  // rate alone already exceeds ref_best_rd, this mode is unlikely to win
  // once transform RD cost and other mode overhead are added.
  int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
  if (mode_rd > ref_best_rd) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }

  // comp_rate[compound_type] == INT_MAX means no stats are cached for this
  // compound type yet, so compute them here. Otherwise the else branch below
  // reuses the cached record.
  if (comp_rate[compound_type] == INT_MAX) {
    if (have_newmv_in_inter_mode(this_mode) &&
        compound_type == COMPOUND_WEDGE &&
        !cpi->sf.disable_interinter_wedge_newmv_search) {
      // Masked motion search for the NEWMV side(s); the new MVs are written
      // to mbmi->mv, then the Y-plane prediction is rebuilt and modelled.
      *out_rate_mv = interinter_compound_motion_search(
          cpi, x, cur_mv, bsize, this_mode, mi_row, mi_col);
      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
                                    AOM_PLANE_Y, AOM_PLANE_Y);

      model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
          &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
      rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
      *comp_model_rd_cur = rd;
      if (rd >= best_rd_cur) {
        // The search did not help; restore the original MVs, MV rate and
        // prediction built from the cached single-reference predictors.
        mbmi->mv[0].as_int = cur_mv[0].as_int;
        mbmi->mv[1].as_int = cur_mv[1].as_int;
        *out_rate_mv = rate_mv;
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                 strides, preds1, strides);
        *comp_model_rd_cur = best_rd_cur;
      }
    } else {
      // No MV refinement: build the masked prediction from the cached
      // predictors and model its RD.
      *out_rate_mv = rate_mv;
      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                               preds1, strides);
      model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
          &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
      *comp_model_rd_cur =
          RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
    }

    RD_STATS rd_stats;

    // Prune when this type's model RD is worse than the best model RD seen
    // so far for the other compound types.
    if (cpi->sf.prune_comp_type_by_model_rd &&
        (*comp_model_rd_cur > comp_best_model_rd) &&
        comp_best_model_rd != INT64_MAX) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
    rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
    if (rd != INT64_MAX) {
      rd =
          RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
      // Back up rate and distortion for future reuse.
      comp_rate[compound_type] = rd_stats.rate;
      comp_dist[compound_type] = rd_stats.dist;
      comp_model_rd[compound_type] = *comp_model_rd_cur;
    }
  } else {
    assert(comp_dist[compound_type] != INT64_MAX);
    // When disable_interinter_wedge_newmv_search is set, motion refinement is
    // disabled, so the cached rate and distortion can be reused for NEWMV
    // modes as well.
    assert(IMPLIES(have_newmv_in_inter_mode(this_mode),
                   cpi->sf.disable_interinter_wedge_newmv_search));
    assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
    assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
    *out_rate_mv = rate_mv;
    // Calculate the RD cost from the stored stats.
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
                comp_dist[compound_type]);
    *comp_model_rd_cur = comp_model_rd[compound_type];
  }
  return rd;
}
8013 
// State carried through the per-mode inter search helpers, such as
// handle_newmv(). Field order is significant if the struct is initialised
// positionally elsewhere — keep it unchanged.
typedef struct {
  // OBMC secondary prediction buffers and their respective strides, one
  // entry per plane.
  uint8_t *above_pred_buf[MAX_MB_PLANE];
  int above_pred_stride[MAX_MB_PLANE];
  uint8_t *left_pred_buf[MAX_MB_PLANE];
  int left_pred_stride[MAX_MB_PLANE];
  // Single-reference NEWMV search results, indexed [ref_mv_idx][ref_frame].
  // handle_newmv() stores the single-reference search result here, and
  // later copies it into cur_mv as the starting MV(s) of compound NEWMV modes.
  int_mv (*single_newmv)[REF_FRAMES];
  // MV rate of the matching single_newmv entry (same indexing).
  int (*single_newmv_rate)[REF_FRAMES];
  // Set to 1 by handle_newmv() once the matching single_newmv entry is stored.
  int (*single_newmv_valid)[REF_FRAMES];
  // Pointer to an array of predicted (modelled) rate-distortion costs.
  // Should point to the first of the 2 arrays in a 2D array.
  // NOTE(review): indexing presumably mirrors
  // InterModeSearchState::modelled_rd[mode][ref_mv_idx][ref_frame] — confirm.
  int64_t (*modelled_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
  // Interpolation filter per [mode][ref_frame]; presumably the filter chosen
  // during the single-reference search — TODO confirm.
  InterpFilter single_filter[MB_MODE_COUNT][REF_FRAMES];
  // NOTE(review): presumably the reference-frame signalling cost and the
  // single/compound selection cost for the current mode — confirm.
  int ref_frame_cost;
  int single_comp_cost;
  // Same shape as modelled_rd; semantics of the "simple" RD are not visible
  // here — TODO document.
  int64_t (*simple_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
  int skip_motion_mode;
  INTERINTRA_MODE *inter_intra_mode;
  int single_ref_first_pass;
  SimpleRDState *simple_rd_state;
} HandleInterModeArgs;
8037 
8038 /* If the current mode shares the same mv with other modes with higher cost,
8039  * skip this mode. */
skip_repeated_mv(const AV1_COMMON * const cm,const MACROBLOCK * const x,PREDICTION_MODE this_mode,const MV_REFERENCE_FRAME ref_frames[2],InterModeSearchState * search_state)8040 static int skip_repeated_mv(const AV1_COMMON *const cm,
8041                             const MACROBLOCK *const x,
8042                             PREDICTION_MODE this_mode,
8043                             const MV_REFERENCE_FRAME ref_frames[2],
8044                             InterModeSearchState *search_state) {
8045   const int is_comp_pred = ref_frames[1] > INTRA_FRAME;
8046   const uint8_t ref_frame_type = av1_ref_frame_type(ref_frames);
8047   const MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
8048   const int ref_mv_count = mbmi_ext->ref_mv_count[ref_frame_type];
8049   PREDICTION_MODE compare_mode = MB_MODE_COUNT;
8050   if (!is_comp_pred) {
8051     if (this_mode == NEARMV) {
8052       if (ref_mv_count == 0) {
8053         // NEARMV has the same motion vector as NEARESTMV
8054         compare_mode = NEARESTMV;
8055       }
8056       if (ref_mv_count == 1 &&
8057           cm->global_motion[ref_frames[0]].wmtype <= TRANSLATION) {
8058         // NEARMV has the same motion vector as GLOBALMV
8059         compare_mode = GLOBALMV;
8060       }
8061     }
8062     if (this_mode == GLOBALMV) {
8063       if (ref_mv_count == 0 &&
8064           cm->global_motion[ref_frames[0]].wmtype <= TRANSLATION) {
8065         // GLOBALMV has the same motion vector as NEARESTMV
8066         compare_mode = NEARESTMV;
8067       }
8068       if (ref_mv_count == 1) {
8069         // GLOBALMV has the same motion vector as NEARMV
8070         compare_mode = NEARMV;
8071       }
8072     }
8073 
8074     if (compare_mode != MB_MODE_COUNT) {
8075       // Use modelled_rd to check whether compare mode was searched
8076       if (search_state->modelled_rd[compare_mode][0][ref_frames[0]] !=
8077           INT64_MAX) {
8078         const int16_t mode_ctx =
8079             av1_mode_context_analyzer(mbmi_ext->mode_context, ref_frames);
8080         const int compare_cost = cost_mv_ref(x, compare_mode, mode_ctx);
8081         const int this_cost = cost_mv_ref(x, this_mode, mode_ctx);
8082 
8083         // Only skip if the mode cost is larger than compare mode cost
8084         if (this_cost > compare_cost) {
8085           search_state->modelled_rd[this_mode][0][ref_frames[0]] =
8086               search_state->modelled_rd[compare_mode][0][ref_frames[0]];
8087           return 1;
8088         }
8089       }
8090     }
8091   }
8092   return 0;
8093 }
8094 
clamp_and_check_mv(int_mv * out_mv,int_mv in_mv,const AV1_COMMON * cm,const MACROBLOCK * x)8095 static INLINE int clamp_and_check_mv(int_mv *out_mv, int_mv in_mv,
8096                                      const AV1_COMMON *cm,
8097                                      const MACROBLOCK *x) {
8098   const MACROBLOCKD *const xd = &x->e_mbd;
8099   *out_mv = in_mv;
8100   lower_mv_precision(&out_mv->as_mv, cm->allow_high_precision_mv,
8101                      cm->cur_frame_force_integer_mv);
8102   clamp_mv2(&out_mv->as_mv, xd);
8103   return !mv_check_bounds(&x->mv_limits, &out_mv->as_mv);
8104 }
8105 
// Produces the motion vector(s) and MV rate for a mode containing NEWMV.
//
// Single-reference modes:
//  - run single_motion_search();
//  - record the result (MV, rate, valid flag) in args->single_newmv* at
//    [ref_mv_idx][ref_frame].
//
// Compound modes:
//  - start each NEWMV side from that stored single-reference MV;
//  - if sf.comp_inter_joint_search_thresh <= bsize, pass the MVs to
//    joint_motion_search() (NEW_NEWMV) or
//    compound_single_motion_search_interinter() (one NEWMV side);
//  - otherwise only compute the MV bit cost.
//
// Writes cur_mv and *rate_mv. Returns INT64_MAX when the single-reference
// search yields INVALID_MV, 0 otherwise.
static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x,
                            const BLOCK_SIZE bsize, int_mv *cur_mv,
                            const int mi_row, const int mi_col,
                            int *const rate_mv,
                            HandleInterModeArgs *const args) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  const PREDICTION_MODE this_mode = mbmi->mode;
  const int ref_mv_idx = mbmi->ref_mv_idx;
  const int ref0 = mbmi->ref_frame[0];

  if (!has_second_ref(mbmi)) {
    single_motion_search(cpi, x, bsize, mi_row, mi_col, 0, rate_mv);
    if (x->best_mv.as_int == INVALID_MV) return INT64_MAX;

    // Remember the result so compound NEWMV modes can start from it.
    args->single_newmv[ref_mv_idx][ref0] = x->best_mv;
    args->single_newmv_rate[ref_mv_idx][ref0] = *rate_mv;
    args->single_newmv_valid[ref_mv_idx][ref0] = 1;

    cur_mv[0].as_int = x->best_mv.as_int;

#if USE_DISCOUNT_NEWMV_TEST
    // Estimate the rate implications of a new mv but discount this
    // under certain circumstances where we want to help initiate a weak
    // motion field, where the distortion gain for a single block may not
    // be enough to overcome the cost of a new mv.
    if (discount_newmv_test(cpi, x, this_mode, x->best_mv)) {
      *rate_mv = AOMMAX(*rate_mv / NEW_MV_DISCOUNT_FACTOR, 1);
    }
#endif
    return 0;
  }

  const int ref1 = mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1];
  const int do_search = cpi->sf.comp_inter_joint_search_thresh <= bsize;

  if (this_mode == NEW_NEWMV) {
    cur_mv[0].as_int = args->single_newmv[ref_mv_idx][ref0].as_int;
    cur_mv[1].as_int = args->single_newmv[ref_mv_idx][ref1].as_int;
    if (do_search) {
      joint_motion_search(cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, NULL,
                          0, rate_mv, 0);
    } else {
      *rate_mv = 0;
      for (int side = 0; side < 2; ++side) {
        const int_mv ref_mv = av1_get_ref_mv(x, side);
        *rate_mv +=
            av1_mv_bit_cost(&cur_mv[side].as_mv, &ref_mv.as_mv, x->nmv_vec_cost,
                            x->mv_cost_stack, MV_COST_WEIGHT);
      }
    }
    return 0;
  }

  // Exactly one side is NEWMV: the second for NEAREST_NEWMV / NEAR_NEWMV,
  // the first for NEW_NEARESTMV / NEW_NEARMV.
  const int new_side = (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV);
  assert(new_side || this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV);
  const int new_ref = new_side ? ref1 : ref0;
  cur_mv[new_side].as_int = args->single_newmv[ref_mv_idx][new_ref].as_int;
  if (do_search) {
    compound_single_motion_search_interinter(cpi, x, bsize, cur_mv, mi_row,
                                             mi_col, NULL, 0, rate_mv, 0,
                                             new_side);
  } else {
    const int_mv ref_mv = av1_get_ref_mv(x, new_side);
    *rate_mv =
        av1_mv_bit_cost(&cur_mv[new_side].as_mv, &ref_mv.as_mv,
                        x->nmv_vec_cost, x->mv_cost_stack, MV_COST_WEIGHT);
  }
  return 0;
}
8186 
swap_dst_buf(MACROBLOCKD * xd,const BUFFER_SET * dst_bufs[2],int num_planes)8187 static INLINE void swap_dst_buf(MACROBLOCKD *xd, const BUFFER_SET *dst_bufs[2],
8188                                 int num_planes) {
8189   const BUFFER_SET *buf0 = dst_bufs[0];
8190   dst_bufs[0] = dst_bufs[1];
8191   dst_bufs[1] = buf0;
8192   restore_dst_buf(xd, *dst_bufs[0], num_planes);
8193 }
8194 
get_switchable_rate(MACROBLOCK * const x,const InterpFilters filters,const int ctx[2])8195 static INLINE int get_switchable_rate(MACROBLOCK *const x,
8196                                       const InterpFilters filters,
8197                                       const int ctx[2]) {
8198   int inter_filter_cost;
8199   const InterpFilter filter0 = av1_extract_interp_filter(filters, 0);
8200   const InterpFilter filter1 = av1_extract_interp_filter(filters, 1);
8201   inter_filter_cost = x->switchable_interp_costs[ctx[0]][filter0];
8202   inter_filter_cost += x->switchable_interp_costs[ctx[1]][filter1];
8203   return SWITCHABLE_INTERP_RATE_FACTOR * inter_filter_cost;
8204 }
8205 
8206 // calculate the rdcost of given interpolation_filter
interpolation_filter_rd(MACROBLOCK * const x,const AV1_COMP * const cpi,const TileDataEnc * tile_data,BLOCK_SIZE bsize,int mi_row,int mi_col,const BUFFER_SET * const orig_dst,int64_t * const rd,int * const switchable_rate,int * const skip_txfm_sb,int64_t * const skip_sse_sb,const BUFFER_SET * dst_bufs[2],int filter_idx,const int switchable_ctx[2],const int skip_pred,int * rate,int64_t * dist)8207 static INLINE int64_t interpolation_filter_rd(
8208     MACROBLOCK *const x, const AV1_COMP *const cpi,
8209     const TileDataEnc *tile_data, BLOCK_SIZE bsize, int mi_row, int mi_col,
8210     const BUFFER_SET *const orig_dst, int64_t *const rd,
8211     int *const switchable_rate, int *const skip_txfm_sb,
8212     int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], int filter_idx,
8213     const int switchable_ctx[2], const int skip_pred, int *rate,
8214     int64_t *dist) {
8215   const AV1_COMMON *cm = &cpi->common;
8216   const int num_planes = av1_num_planes(cm);
8217   MACROBLOCKD *const xd = &x->e_mbd;
8218   MB_MODE_INFO *const mbmi = xd->mi[0];
8219   int tmp_rate[2], tmp_skip_sb[2] = { 1, 1 };
8220   int64_t tmp_dist[2], tmp_skip_sse[2] = { 0, 0 };
8221 
8222   const InterpFilters last_best = mbmi->interp_filters;
8223   mbmi->interp_filters = filter_sets[filter_idx];
8224   const int tmp_rs =
8225       get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);
8226 
8227   int64_t min_rd = RDCOST(x->rdmult, tmp_rs, 0);
8228   if (min_rd > *rd) {
8229     mbmi->interp_filters = last_best;
8230     return 0;
8231   }
8232 
8233   (void)tile_data;
8234 
8235   assert(skip_pred != 2);
8236   assert((skip_pred >= 0) && (skip_pred <= cpi->default_interp_skip_flags));
8237   assert(rate[0] >= 0);
8238   assert(dist[0] >= 0);
8239   assert((skip_txfm_sb[0] == 0) || (skip_txfm_sb[0] == 1));
8240   assert(skip_sse_sb[0] >= 0);
8241   assert(rate[1] >= 0);
8242   assert(dist[1] >= 0);
8243   assert((skip_txfm_sb[1] == 0) || (skip_txfm_sb[1] == 1));
8244   assert(skip_sse_sb[1] >= 0);
8245 
8246   if (skip_pred != cpi->default_interp_skip_flags) {
8247     if (skip_pred != DEFAULT_LUMA_INTERP_SKIP_FLAG) {
8248       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
8249                                     AOM_PLANE_Y, AOM_PLANE_Y);
8250 #if CONFIG_COLLECT_RD_STATS == 3
8251       RD_STATS rd_stats_y;
8252       pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
8253                             INT64_MAX);
8254       PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
8255 #endif  // CONFIG_COLLECT_RD_STATS == 3
8256       model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
8257           cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &tmp_rate[0], &tmp_dist[0],
8258           &tmp_skip_sb[0], &tmp_skip_sse[0], NULL, NULL, NULL);
8259       tmp_rate[1] = tmp_rate[0];
8260       tmp_dist[1] = tmp_dist[0];
8261     } else {
8262       // only luma MC is skipped
8263       tmp_rate[1] = rate[0];
8264       tmp_dist[1] = dist[0];
8265     }
8266     if (num_planes > 1) {
8267       for (int plane = 1; plane < num_planes; ++plane) {
8268         int tmp_rate_uv, tmp_skip_sb_uv;
8269         int64_t tmp_dist_uv, tmp_skip_sse_uv;
8270         int64_t tmp_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate[1], tmp_dist[1]);
8271         if (tmp_rd >= *rd) {
8272           mbmi->interp_filters = last_best;
8273           return 0;
8274         }
8275         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
8276                                       plane, plane);
8277         model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
8278             cpi, bsize, x, xd, plane, plane, mi_row, mi_col, &tmp_rate_uv,
8279             &tmp_dist_uv, &tmp_skip_sb_uv, &tmp_skip_sse_uv, NULL, NULL, NULL);
8280         tmp_rate[1] =
8281             (int)AOMMIN(((int64_t)tmp_rate[1] + (int64_t)tmp_rate_uv), INT_MAX);
8282         tmp_dist[1] += tmp_dist_uv;
8283         tmp_skip_sb[1] &= tmp_skip_sb_uv;
8284         tmp_skip_sse[1] += tmp_skip_sse_uv;
8285       }
8286     }
8287   } else {
8288     // both luma and chroma MC is skipped
8289     tmp_rate[1] = rate[1];
8290     tmp_dist[1] = dist[1];
8291   }
8292   int64_t tmp_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate[1], tmp_dist[1]);
8293 
8294   if (tmp_rd < *rd) {
8295     *rd = tmp_rd;
8296     *switchable_rate = tmp_rs;
8297     if (skip_pred != cpi->default_interp_skip_flags) {
8298       if (skip_pred == 0) {
8299         // Overwrite the data as current filter is the best one
8300         tmp_skip_sb[1] = tmp_skip_sb[0] & tmp_skip_sb[1];
8301         tmp_skip_sse[1] = tmp_skip_sse[0] + tmp_skip_sse[1];
8302         memcpy(rate, tmp_rate, sizeof(*rate) * 2);
8303         memcpy(dist, tmp_dist, sizeof(*dist) * 2);
8304         memcpy(skip_txfm_sb, tmp_skip_sb, sizeof(*skip_txfm_sb) * 2);
8305         memcpy(skip_sse_sb, tmp_skip_sse, sizeof(*skip_sse_sb) * 2);
8306         // As luma MC data is computed, no need to recompute after the search
8307         x->recalc_luma_mc_data = 0;
8308       } else if (skip_pred == DEFAULT_LUMA_INTERP_SKIP_FLAG) {
8309         // As luma MC data is not computed, update of luma data can be skipped
8310         rate[1] = tmp_rate[1];
8311         dist[1] = tmp_dist[1];
8312         skip_txfm_sb[1] = skip_txfm_sb[0] & tmp_skip_sb[1];
8313         skip_sse_sb[1] = skip_sse_sb[0] + tmp_skip_sse[1];
8314         // As luma MC data is not recomputed and current filter is the best,
8315         // indicate the possibility of recomputing MC data
8316         // If current buffer contains valid MC data, toggle to indicate that
8317         // luma MC data needs to be recomputed
8318         x->recalc_luma_mc_data ^= 1;
8319       }
8320       swap_dst_buf(xd, dst_bufs, num_planes);
8321     }
8322     return 1;
8323   }
8324   mbmi->interp_filters = last_best;
8325   return 0;
8326 }
8327 
pred_dual_interp_filter_rd(MACROBLOCK * const x,const AV1_COMP * const cpi,const TileDataEnc * tile_data,BLOCK_SIZE bsize,int mi_row,int mi_col,const BUFFER_SET * const orig_dst,int64_t * const rd,int * const switchable_rate,int * const skip_txfm_sb,int64_t * const skip_sse_sb,const BUFFER_SET * dst_bufs[2],InterpFilters filter_idx,const int switchable_ctx[2],const int skip_pred,int * rate,int64_t * dist,InterpFilters af_horiz,InterpFilters af_vert,InterpFilters lf_horiz,InterpFilters lf_vert)8328 static INLINE void pred_dual_interp_filter_rd(
8329     MACROBLOCK *const x, const AV1_COMP *const cpi,
8330     const TileDataEnc *tile_data, BLOCK_SIZE bsize, int mi_row, int mi_col,
8331     const BUFFER_SET *const orig_dst, int64_t *const rd,
8332     int *const switchable_rate, int *const skip_txfm_sb,
8333     int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2],
8334     InterpFilters filter_idx, const int switchable_ctx[2], const int skip_pred,
8335     int *rate, int64_t *dist, InterpFilters af_horiz, InterpFilters af_vert,
8336     InterpFilters lf_horiz, InterpFilters lf_vert) {
8337   if ((af_horiz == lf_horiz) && (af_horiz != SWITCHABLE)) {
8338     if (((af_vert == lf_vert) && (af_vert != SWITCHABLE))) {
8339       filter_idx = af_horiz + (af_vert * SWITCHABLE_FILTERS);
8340       if (filter_idx) {
8341         interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8342                                 orig_dst, rd, switchable_rate, skip_txfm_sb,
8343                                 skip_sse_sb, dst_bufs, filter_idx,
8344                                 switchable_ctx, skip_pred, rate, dist);
8345       }
8346     } else {
8347       for (filter_idx = af_horiz; filter_idx < (DUAL_FILTER_SET_SIZE);
8348            filter_idx += SWITCHABLE_FILTERS) {
8349         if (filter_idx) {
8350           interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8351                                   orig_dst, rd, switchable_rate, skip_txfm_sb,
8352                                   skip_sse_sb, dst_bufs, filter_idx,
8353                                   switchable_ctx, skip_pred, rate, dist);
8354         }
8355       }
8356     }
8357   } else if ((af_vert == lf_vert) && (af_vert != SWITCHABLE)) {
8358     for (filter_idx = (af_vert * SWITCHABLE_FILTERS);
8359          filter_idx <= ((af_vert * SWITCHABLE_FILTERS) + 2); filter_idx += 1) {
8360       if (filter_idx) {
8361         interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8362                                 orig_dst, rd, switchable_rate, skip_txfm_sb,
8363                                 skip_sse_sb, dst_bufs, filter_idx,
8364                                 switchable_ctx, skip_pred, rate, dist);
8365       }
8366     }
8367   }
8368 }
8369 
8370 // Find the best interp filter if dual_interp_filter = 0
find_best_non_dual_interp_filter(MACROBLOCK * const x,const AV1_COMP * const cpi,const TileDataEnc * tile_data,BLOCK_SIZE bsize,int mi_row,int mi_col,const BUFFER_SET * const orig_dst,int64_t * const rd,int * const switchable_rate,int * const skip_txfm_sb,int64_t * const skip_sse_sb,const BUFFER_SET * dst_bufs[2],const int switchable_ctx[2],const int skip_ver,const int skip_hor,int * rate,int64_t * dist,int filter_set_size)8371 static INLINE void find_best_non_dual_interp_filter(
8372     MACROBLOCK *const x, const AV1_COMP *const cpi,
8373     const TileDataEnc *tile_data, BLOCK_SIZE bsize, int mi_row, int mi_col,
8374     const BUFFER_SET *const orig_dst, int64_t *const rd,
8375     int *const switchable_rate, int *const skip_txfm_sb,
8376     int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2],
8377     const int switchable_ctx[2], const int skip_ver, const int skip_hor,
8378     int *rate, int64_t *dist, int filter_set_size) {
8379   int16_t i;
8380   MACROBLOCKD *const xd = &x->e_mbd;
8381   MB_MODE_INFO *const mbmi = xd->mi[0];
8382 
8383   // Regular filter evaluation should have been done and hence the same should
8384   // be the winner
8385   assert(x->e_mbd.mi[0]->interp_filters == filter_sets[0]);
8386   assert(filter_set_size == DUAL_FILTER_SET_SIZE);
8387   if ((skip_hor & skip_ver) != cpi->default_interp_skip_flags) {
8388     const AV1_COMMON *cm = &cpi->common;
8389     int bsl, pred_filter_search;
8390     InterpFilters af = SWITCHABLE, lf = SWITCHABLE, filter_idx = 0;
8391     const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
8392     const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
8393     bsl = mi_size_wide_log2[bsize];
8394     pred_filter_search =
8395         cpi->sf.cb_pred_filter_search
8396             ? (((mi_row + mi_col) >> bsl) +
8397                get_chessboard_index(cm->current_frame.frame_number)) &
8398                   0x1
8399             : 0;
8400     if (above_mbmi && is_inter_block(above_mbmi)) {
8401       af = above_mbmi->interp_filters;
8402     }
8403     if (left_mbmi && is_inter_block(left_mbmi)) {
8404       lf = left_mbmi->interp_filters;
8405     }
8406     pred_filter_search &= ((af == lf) && (af != SWITCHABLE));
8407     if (pred_filter_search) {
8408       filter_idx = SWITCHABLE * (af & 0xf);
8409       // This assert tells that (filter_x == filter_y) for non-dual filter case
8410       assert((filter_sets[filter_idx] & 0xffff) ==
8411              (filter_sets[filter_idx] >> 16));
8412       if (cpi->sf.adaptive_interp_filter_search &&
8413           (cpi->sf.interp_filter_search_mask & (1 << (filter_idx >> 2)))) {
8414         return;
8415       }
8416       if (filter_idx) {
8417         interpolation_filter_rd(
8418             x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd,
8419             switchable_rate, skip_txfm_sb, skip_sse_sb, dst_bufs, filter_idx,
8420             switchable_ctx, (skip_hor & skip_ver), rate, dist);
8421       }
8422       return;
8423     }
8424   }
8425   // Reuse regular filter's modeled rd data for sharp filter for following
8426   // cases
8427   // 1) When bsize is 4x4
8428   // 2) When block width is 4 (i.e. 4x8/4x16 blocks) and MV in vertical
8429   // direction is full-pel
8430   // 3) When block height is 4 (i.e. 8x4/16x4 blocks) and MV in horizontal
8431   // direction is full-pel
8432   // TODO(any): Optimize cases 2 and 3 further if luma MV in relavant direction
8433   // alone is full-pel
8434 
8435   if ((bsize == BLOCK_4X4) ||
8436       (block_size_wide[bsize] == 4 &&
8437        skip_ver == cpi->default_interp_skip_flags) ||
8438       (block_size_high[bsize] == 4 &&
8439        skip_hor == cpi->default_interp_skip_flags)) {
8440     int skip_pred = cpi->default_interp_skip_flags;
8441     for (i = filter_set_size - 1; i > 0; i -= (SWITCHABLE_FILTERS + 1)) {
8442       // This assert tells that (filter_x == filter_y) for non-dual filter case
8443       assert((filter_sets[i] & 0xffff) == (filter_sets[i] >> 16));
8444       if (cpi->sf.adaptive_interp_filter_search &&
8445           (cpi->sf.interp_filter_search_mask & (1 << (i >> 2)))) {
8446         continue;
8447       }
8448       interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8449                               orig_dst, rd, switchable_rate, skip_txfm_sb,
8450                               skip_sse_sb, dst_bufs, i, switchable_ctx,
8451                               skip_pred, rate, dist);
8452       skip_pred = (skip_hor & skip_ver);
8453     }
8454   } else {
8455     int skip_pred = (skip_hor & skip_ver);
8456     for (i = (SWITCHABLE_FILTERS + 1); i < filter_set_size;
8457          i += (SWITCHABLE_FILTERS + 1)) {
8458       // This assert tells that (filter_x == filter_y) for non-dual filter case
8459       assert((filter_sets[i] & 0xffff) == (filter_sets[i] >> 16));
8460       if (cpi->sf.adaptive_interp_filter_search &&
8461           (cpi->sf.interp_filter_search_mask & (1 << (i >> 2)))) {
8462         continue;
8463       }
8464       interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8465                               orig_dst, rd, switchable_rate, skip_txfm_sb,
8466                               skip_sse_sb, dst_bufs, i, switchable_ctx,
8467                               skip_pred, rate, dist);
8468       // In first iteration, smooth filter is evaluated. If smooth filter
8469       // (which is less sharper) is the winner among regular and smooth filters,
8470       // sharp filter evaluation is skipped
8471       // TODO(any): Refine this gating based on modelled rd only (i.e., by not
8472       // accounting switchable filter rate)
8473       if (cpi->sf.skip_sharp_interp_filter_search &&
8474           skip_pred != cpi->default_interp_skip_flags) {
8475         if (mbmi->interp_filters == filter_sets[(SWITCHABLE_FILTERS + 1)])
8476           break;
8477       }
8478     }
8479   }
8480 }
8481 
8482 // check if there is saved result match with this search
is_interp_filter_match(const INTERPOLATION_FILTER_STATS * st,MB_MODE_INFO * const mi)8483 static INLINE int is_interp_filter_match(const INTERPOLATION_FILTER_STATS *st,
8484                                          MB_MODE_INFO *const mi) {
8485   for (int i = 0; i < 2; ++i) {
8486     if ((st->ref_frames[i] != mi->ref_frame[i]) ||
8487         (st->mv[i].as_int != mi->mv[i].as_int)) {
8488       return 0;
8489     }
8490   }
8491   if (has_second_ref(mi) && st->comp_type != mi->interinter_comp.type) return 0;
8492   return 1;
8493 }
8494 
8495 // Checks if characteristics of search match
is_comp_rd_match(const AV1_COMP * const cpi,const MACROBLOCK * const x,const COMP_RD_STATS * st,const MB_MODE_INFO * const mi,int32_t * comp_rate,int64_t * comp_dist,int64_t * comp_model_rd)8496 static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
8497                                    const MACROBLOCK *const x,
8498                                    const COMP_RD_STATS *st,
8499                                    const MB_MODE_INFO *const mi,
8500                                    int32_t *comp_rate, int64_t *comp_dist,
8501                                    int64_t *comp_model_rd) {
8502   // TODO(ranjit): Ensure that compound type search use regular filter always
8503   // and check if following check can be removed
8504   // Check if interp filter matches with previous case
8505   if (st->filter != mi->interp_filters) return 0;
8506 
8507   const MACROBLOCKD *const xd = &x->e_mbd;
8508   // Match MV and reference indices
8509   for (int i = 0; i < 2; ++i) {
8510     if ((st->ref_frames[i] != mi->ref_frame[i]) ||
8511         (st->mv[i].as_int != mi->mv[i].as_int)) {
8512       return 0;
8513     }
8514     const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
8515     if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
8516   }
8517 
8518   // Store the stats for compound average
8519   comp_rate[COMPOUND_AVERAGE] = st->rate[COMPOUND_AVERAGE];
8520   comp_dist[COMPOUND_AVERAGE] = st->dist[COMPOUND_AVERAGE];
8521   comp_model_rd[COMPOUND_AVERAGE] = st->comp_model_rd[COMPOUND_AVERAGE];
8522   comp_rate[COMPOUND_DISTWTD] = st->rate[COMPOUND_DISTWTD];
8523   comp_dist[COMPOUND_DISTWTD] = st->dist[COMPOUND_DISTWTD];
8524   comp_model_rd[COMPOUND_DISTWTD] = st->comp_model_rd[COMPOUND_DISTWTD];
8525 
8526   // For compound wedge/segment, reuse data only if NEWMV is not present in
8527   // either of the directions
8528   if ((!have_newmv_in_inter_mode(mi->mode) &&
8529        !have_newmv_in_inter_mode(st->mode)) ||
8530       (cpi->sf.disable_interinter_wedge_newmv_search)) {
8531     memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
8532            sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
8533     memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
8534            sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
8535     memcpy(&comp_model_rd[COMPOUND_WEDGE], &st->comp_model_rd[COMPOUND_WEDGE],
8536            sizeof(comp_model_rd[COMPOUND_WEDGE]) * 2);
8537   }
8538   return 1;
8539 }
8540 
find_interp_filter_in_stats(MACROBLOCK * x,MB_MODE_INFO * const mbmi)8541 static INLINE int find_interp_filter_in_stats(MACROBLOCK *x,
8542                                               MB_MODE_INFO *const mbmi) {
8543   const int comp_idx = mbmi->compound_idx;
8544   const int offset = x->interp_filter_stats_idx[comp_idx];
8545   for (int j = 0; j < offset; ++j) {
8546     const INTERPOLATION_FILTER_STATS *st = &x->interp_filter_stats[comp_idx][j];
8547     if (is_interp_filter_match(st, mbmi)) {
8548       mbmi->interp_filters = st->filters;
8549       return j;
8550     }
8551   }
8552   return -1;  // no match result found
8553 }
8554 // Checks if similar compound type search case is accounted earlier
8555 // If found, returns relevant rd data
find_comp_rd_in_stats(const AV1_COMP * const cpi,const MACROBLOCK * x,const MB_MODE_INFO * const mbmi,int32_t * comp_rate,int64_t * comp_dist,int64_t * comp_model_rd)8556 static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
8557                                         const MACROBLOCK *x,
8558                                         const MB_MODE_INFO *const mbmi,
8559                                         int32_t *comp_rate, int64_t *comp_dist,
8560                                         int64_t *comp_model_rd) {
8561   for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
8562     if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
8563                          comp_dist, comp_model_rd)) {
8564       return 1;
8565     }
8566   }
8567   return 0;  // no match result found
8568 }
8569 
save_interp_filter_search_stat(MACROBLOCK * x,MB_MODE_INFO * const mbmi,int64_t rd,int skip_txfm_sb,int64_t skip_sse_sb,unsigned int pred_sse)8570 static INLINE void save_interp_filter_search_stat(MACROBLOCK *x,
8571                                                   MB_MODE_INFO *const mbmi,
8572                                                   int64_t rd, int skip_txfm_sb,
8573                                                   int64_t skip_sse_sb,
8574                                                   unsigned int pred_sse) {
8575   const int comp_idx = mbmi->compound_idx;
8576   const int offset = x->interp_filter_stats_idx[comp_idx];
8577   if (offset < MAX_INTERP_FILTER_STATS) {
8578     INTERPOLATION_FILTER_STATS stat = { mbmi->interp_filters,
8579                                         { mbmi->mv[0], mbmi->mv[1] },
8580                                         { mbmi->ref_frame[0],
8581                                           mbmi->ref_frame[1] },
8582                                         mbmi->interinter_comp.type,
8583                                         rd,
8584                                         skip_txfm_sb,
8585                                         skip_sse_sb,
8586                                         pred_sse };
8587     x->interp_filter_stats[comp_idx][offset] = stat;
8588     x->interp_filter_stats_idx[comp_idx]++;
8589   }
8590 }
8591 
save_comp_rd_search_stat(MACROBLOCK * x,const MB_MODE_INFO * const mbmi,const int32_t * comp_rate,const int64_t * comp_dist,const int64_t * comp_model_rd,const int_mv * cur_mv)8592 static INLINE void save_comp_rd_search_stat(MACROBLOCK *x,
8593                                             const MB_MODE_INFO *const mbmi,
8594                                             const int32_t *comp_rate,
8595                                             const int64_t *comp_dist,
8596                                             const int64_t *comp_model_rd,
8597                                             const int_mv *cur_mv) {
8598   const int offset = x->comp_rd_stats_idx;
8599   if (offset < MAX_COMP_RD_STATS) {
8600     COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
8601     memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
8602     memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
8603     memcpy(rd_stats->comp_model_rd, comp_model_rd,
8604            sizeof(rd_stats->comp_model_rd));
8605     memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
8606     memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
8607     rd_stats->mode = mbmi->mode;
8608     rd_stats->filter = mbmi->interp_filters;
8609     rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
8610     const MACROBLOCKD *const xd = &x->e_mbd;
8611     for (int i = 0; i < 2; ++i) {
8612       const WarpedMotionParams *const wm =
8613           &xd->global_motion[mbmi->ref_frame[i]];
8614       rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
8615     }
8616     ++x->comp_rd_stats_idx;
8617   }
8618 }
8619 
interpolation_filter_search(MACROBLOCK * const x,const AV1_COMP * const cpi,const TileDataEnc * tile_data,BLOCK_SIZE bsize,int mi_row,int mi_col,const BUFFER_SET * const tmp_dst,const BUFFER_SET * const orig_dst,InterpFilter (* const single_filter)[REF_FRAMES],int64_t * const rd,int * const switchable_rate,int * const skip_txfm_sb,int64_t * const skip_sse_sb,int * skip_build_pred,HandleInterModeArgs * args,int64_t ref_best_rd)8620 static int64_t interpolation_filter_search(
8621     MACROBLOCK *const x, const AV1_COMP *const cpi,
8622     const TileDataEnc *tile_data, BLOCK_SIZE bsize, int mi_row, int mi_col,
8623     const BUFFER_SET *const tmp_dst, const BUFFER_SET *const orig_dst,
8624     InterpFilter (*const single_filter)[REF_FRAMES], int64_t *const rd,
8625     int *const switchable_rate, int *const skip_txfm_sb,
8626     int64_t *const skip_sse_sb, int *skip_build_pred, HandleInterModeArgs *args,
8627     int64_t ref_best_rd) {
8628   const AV1_COMMON *cm = &cpi->common;
8629   const int num_planes = av1_num_planes(cm);
8630   MACROBLOCKD *const xd = &x->e_mbd;
8631   MB_MODE_INFO *const mbmi = xd->mi[0];
8632   const int need_search =
8633       av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd);
8634   int i;
8635   // Index 0 corresponds to luma rd data and index 1 corresponds to cummulative
8636   // data of all planes
8637   int tmp_rate[2] = { 0, 0 };
8638   int64_t tmp_dist[2] = { 0, 0 };
8639   int best_skip_txfm_sb[2] = { 1, 1 };
8640   int64_t best_skip_sse_sb[2] = { 0, 0 };
8641   const int ref_frame = xd->mi[0]->ref_frame[0];
8642 
8643   (void)single_filter;
8644   int match_found_idx = -1;
8645   const InterpFilter assign_filter = cm->interp_filter;
8646   if (cpi->sf.skip_repeat_interpolation_filter_search && need_search) {
8647     match_found_idx = find_interp_filter_in_stats(x, mbmi);
8648   }
8649   if (match_found_idx != -1) {
8650     const int comp_idx = mbmi->compound_idx;
8651     *rd = x->interp_filter_stats[comp_idx][match_found_idx].rd;
8652     *skip_txfm_sb =
8653         x->interp_filter_stats[comp_idx][match_found_idx].skip_txfm_sb;
8654     *skip_sse_sb =
8655         x->interp_filter_stats[comp_idx][match_found_idx].skip_sse_sb;
8656     x->pred_sse[ref_frame] =
8657         x->interp_filter_stats[comp_idx][match_found_idx].pred_sse;
8658     return 0;
8659   }
8660   if (!need_search || match_found_idx == -1) {
8661     set_default_interp_filters(mbmi, assign_filter);
8662   }
8663   int switchable_ctx[2];
8664   switchable_ctx[0] = av1_get_pred_context_switchable_interp(xd, 0);
8665   switchable_ctx[1] = av1_get_pred_context_switchable_interp(xd, 1);
8666   *switchable_rate =
8667       get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);
8668   if (!(*skip_build_pred)) {
8669     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize, 0,
8670                                   av1_num_planes(cm) - 1);
8671     *skip_build_pred = 1;
8672   }
8673 
8674 #if CONFIG_COLLECT_RD_STATS == 3
8675   RD_STATS rd_stats_y;
8676   pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
8677   PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
8678 #endif  // CONFIG_COLLECT_RD_STATS == 3
8679   model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
8680       cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &tmp_rate[0], &tmp_dist[0],
8681       &best_skip_txfm_sb[0], &best_skip_sse_sb[0], NULL, NULL, NULL);
8682   if (num_planes > 1)
8683     model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
8684         cpi, bsize, x, xd, 1, num_planes - 1, mi_row, mi_col, &tmp_rate[1],
8685         &tmp_dist[1], &best_skip_txfm_sb[1], &best_skip_sse_sb[1], NULL, NULL,
8686         NULL);
8687   tmp_rate[1] =
8688       (int)AOMMIN((int64_t)tmp_rate[0] + (int64_t)tmp_rate[1], INT_MAX);
8689   assert(tmp_rate[1] >= 0);
8690   tmp_dist[1] = tmp_dist[0] + tmp_dist[1];
8691   best_skip_txfm_sb[1] = best_skip_txfm_sb[0] & best_skip_txfm_sb[1];
8692   best_skip_sse_sb[1] = best_skip_sse_sb[0] + best_skip_sse_sb[1];
8693   *rd = RDCOST(x->rdmult, (*switchable_rate + tmp_rate[1]), tmp_dist[1]);
8694   *skip_txfm_sb = best_skip_txfm_sb[1];
8695   *skip_sse_sb = best_skip_sse_sb[1];
8696   x->pred_sse[ref_frame] = (unsigned int)(best_skip_sse_sb[0] >> 4);
8697 
8698   if (assign_filter != SWITCHABLE || match_found_idx != -1) {
8699     return 0;
8700   }
8701   if (!need_search) {
8702     assert(mbmi->interp_filters ==
8703            av1_broadcast_interp_filter(EIGHTTAP_REGULAR));
8704     return 0;
8705   }
8706   if (args->modelled_rd != NULL) {
8707     if (has_second_ref(mbmi)) {
8708       const int ref_mv_idx = mbmi->ref_mv_idx;
8709       int refs[2] = { mbmi->ref_frame[0],
8710                       (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
8711       const int mode0 = compound_ref0_mode(mbmi->mode);
8712       const int mode1 = compound_ref1_mode(mbmi->mode);
8713       const int64_t mrd = AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
8714                                  args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
8715       if ((*rd >> 1) > mrd && ref_best_rd < INT64_MAX) {
8716         return INT64_MAX;
8717       }
8718     }
8719   }
8720 
8721   x->recalc_luma_mc_data = 0;
8722   // skip_flag=xx (in binary form)
8723   // Setting 0th flag corresonds to skipping luma MC and setting 1st bt
8724   // corresponds to skipping chroma MC  skip_flag=0 corresponds to "Don't skip
8725   // luma and chroma MC"  Skip flag=1 corresponds to "Skip Luma MC only"
8726   // Skip_flag=2 is not a valid case
8727   // skip_flag=3 corresponds to "Skip both luma and chroma MC"
8728   int skip_hor = cpi->default_interp_skip_flags;
8729   int skip_ver = cpi->default_interp_skip_flags;
8730   const int is_compound = has_second_ref(mbmi);
8731   assert(is_intrabc_block(mbmi) == 0);
8732   for (int j = 0; j < 1 + is_compound; ++j) {
8733     const struct scale_factors *const sf =
8734         get_ref_scale_factors_const(cm, mbmi->ref_frame[j]);
8735     // TODO(any): Refine skip flag calculation considering scaling
8736     if (av1_is_scaled(sf)) {
8737       skip_hor = 0;
8738       skip_ver = 0;
8739       break;
8740     }
8741     const MV mv = mbmi->mv[j].as_mv;
8742     int skip_hor_plane = 0;
8743     int skip_ver_plane = 0;
8744     for (int k = 0; k < AOMMAX(1, (num_planes - 1)); ++k) {
8745       struct macroblockd_plane *const pd = &xd->plane[k];
8746       const int bw = pd->width;
8747       const int bh = pd->height;
8748       const MV mv_q4 = clamp_mv_to_umv_border_sb(
8749           xd, &mv, bw, bh, pd->subsampling_x, pd->subsampling_y);
8750       const int sub_x = (mv_q4.col & SUBPEL_MASK) << SCALE_EXTRA_BITS;
8751       const int sub_y = (mv_q4.row & SUBPEL_MASK) << SCALE_EXTRA_BITS;
8752       skip_hor_plane |= ((sub_x == 0) << k);
8753       skip_ver_plane |= ((sub_y == 0) << k);
8754     }
8755     skip_hor = skip_hor & skip_hor_plane;
8756     skip_ver = skip_ver & skip_ver_plane;
8757     // It is not valid that "luma MV is sub-pel, whereas chroma MV is not"
8758     assert(skip_hor != 2);
8759     assert(skip_ver != 2);
8760   }
8761   // When compond prediction type is compound segment wedge, luma MC and chroma
8762   // MC need to go hand in hand as mask generated during luma MC is reuired for
8763   // chroma MC. If skip_hor = 0 and skip_ver = 1, mask used for chroma MC during
8764   // vertical filter decision may be incorrect as temporary MC evaluation
8765   // overwrites the mask. Make skip_ver as 0 for this case so that mask is
8766   // populated during luma MC
8767   if (is_compound && mbmi->compound_idx == 1 &&
8768       mbmi->interinter_comp.type == COMPOUND_DIFFWTD) {
8769     assert(mbmi->comp_group_idx == 1);
8770     if (skip_hor == 0 && skip_ver == 1) skip_ver = 0;
8771   }
8772   // do interp_filter search
8773   const int filter_set_size = DUAL_FILTER_SET_SIZE;
8774   restore_dst_buf(xd, *tmp_dst, num_planes);
8775   const BUFFER_SET *dst_bufs[2] = { tmp_dst, orig_dst };
8776   if (cpi->sf.use_fast_interpolation_filter_search &&
8777       cm->seq_params.enable_dual_filter) {
8778     // default to (R,R): EIGHTTAP_REGULARxEIGHTTAP_REGULAR
8779     int best_dual_mode = 0;
8780     // Find best of {R}x{R,Sm,Sh}
8781     const int bw = block_size_wide[bsize];
8782     const int bh = block_size_high[bsize];
8783     int skip_pred;
8784     int bsl, pred_filter_search;
8785     InterpFilters af_horiz = SWITCHABLE, af_vert = SWITCHABLE,
8786                   lf_horiz = SWITCHABLE, lf_vert = SWITCHABLE, filter_idx = 0;
8787     const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
8788     const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
8789     bsl = mi_size_wide_log2[bsize];
8790     pred_filter_search =
8791         cpi->sf.cb_pred_filter_search
8792             ? (((mi_row + mi_col) >> bsl) +
8793                get_chessboard_index(cm->current_frame.frame_number)) &
8794                   0x1
8795             : 0;
8796     if (above_mbmi && is_inter_block(above_mbmi)) {
8797       af_horiz = av1_extract_interp_filter(above_mbmi->interp_filters, 1);
8798       af_vert = av1_extract_interp_filter(above_mbmi->interp_filters, 0);
8799     }
8800     if (left_mbmi && is_inter_block(left_mbmi)) {
8801       lf_horiz = av1_extract_interp_filter(left_mbmi->interp_filters, 1);
8802       lf_vert = av1_extract_interp_filter(left_mbmi->interp_filters, 0);
8803     }
8804     pred_filter_search &= !have_newmv_in_inter_mode(mbmi->mode);
8805     pred_filter_search &=
8806         ((af_horiz == lf_horiz) && (af_horiz != SWITCHABLE)) ||
8807         ((af_vert == lf_vert) && (af_vert != SWITCHABLE));
8808     if (pred_filter_search) {
8809       pred_dual_interp_filter_rd(
8810           x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd,
8811           switchable_rate, best_skip_txfm_sb, best_skip_sse_sb, dst_bufs,
8812           filter_idx, switchable_ctx, (skip_hor & skip_ver), tmp_rate, tmp_dist,
8813           af_horiz, af_vert, lf_horiz, lf_vert);
8814     } else {
8815       skip_pred = bw <= 4 ? cpi->default_interp_skip_flags : skip_hor;
8816       for (i = (SWITCHABLE_FILTERS - 1); i >= 1; --i) {
8817         if (interpolation_filter_rd(
8818                 x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd,
8819                 switchable_rate, best_skip_txfm_sb, best_skip_sse_sb, dst_bufs,
8820                 i, switchable_ctx, skip_pred, tmp_rate, tmp_dist)) {
8821           best_dual_mode = i;
8822         }
8823         skip_pred = skip_hor;
8824       }
8825       // From best of horizontal EIGHTTAP_REGULAR modes, check vertical modes
8826       skip_pred = bh <= 4 ? cpi->default_interp_skip_flags : skip_ver;
8827       assert(filter_set_size == DUAL_FILTER_SET_SIZE);
8828       for (i = (best_dual_mode + (SWITCHABLE_FILTERS * 2));
8829            i >= (best_dual_mode + SWITCHABLE_FILTERS);
8830            i -= SWITCHABLE_FILTERS) {
8831         interpolation_filter_rd(
8832             x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd,
8833             switchable_rate, best_skip_txfm_sb, best_skip_sse_sb, dst_bufs, i,
8834             switchable_ctx, skip_pred, tmp_rate, tmp_dist);
8835         skip_pred = skip_ver;
8836       }
8837     }
8838   } else if (cm->seq_params.enable_dual_filter == 0) {
8839     find_best_non_dual_interp_filter(
8840         x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate,
8841         best_skip_txfm_sb, best_skip_sse_sb, dst_bufs, switchable_ctx, skip_ver,
8842         skip_hor, tmp_rate, tmp_dist, filter_set_size);
8843   } else {
8844     // EIGHTTAP_REGULAR mode is calculated beforehand
8845     for (i = 1; i < filter_set_size; ++i) {
8846       interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
8847                               orig_dst, rd, switchable_rate, best_skip_txfm_sb,
8848                               best_skip_sse_sb, dst_bufs, i, switchable_ctx,
8849                               (skip_hor & skip_ver), tmp_rate, tmp_dist);
8850     }
8851   }
8852   swap_dst_buf(xd, dst_bufs, num_planes);
8853   // Recompute final MC data if required
8854   if (x->recalc_luma_mc_data == 1) {
8855     // Recomputing final luma MC data is required only if the same was skipped
8856     // in either of the directions  Condition below is necessary, but not
8857     // sufficient
8858     assert((skip_hor == 1) || (skip_ver == 1));
8859     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
8860                                   AOM_PLANE_Y, AOM_PLANE_Y);
8861   }
8862   *skip_txfm_sb = best_skip_txfm_sb[1];
8863   *skip_sse_sb = best_skip_sse_sb[1];
8864   x->pred_sse[ref_frame] = (unsigned int)(best_skip_sse_sb[0] >> 4);
8865 
8866   // save search results
8867   if (cpi->sf.skip_repeat_interpolation_filter_search) {
8868     assert(match_found_idx == -1);
8869     save_interp_filter_search_stat(x, mbmi, *rd, *skip_txfm_sb, *skip_sse_sb,
8870                                    x->pred_sse[ref_frame]);
8871   }
8872   return 0;
8873 }
8874 
// Runs the transform search for the prediction currently held in the
// destination buffers: luma first (pick_tx_size_type_yrd or super_block_yrd),
// then chroma (super_block_uvrd) when the frame has more than one plane, and
// finally decides between coding the residual and signalling skip.
//
// mode_rate is the rate already spent on signalling the mode; ref_best_rd is
// the best RD cost found so far and drives the early terminations.
//
// Returns 1 when rd_stats, rd_stats_y and rd_stats_uv hold a complete result
// (rd_stats->rate then includes mode_rate and the chosen skip-flag cost, and
// mbmi->skip reflects the skip decision). Returns 0 when the search stopped
// early. In some of those cases rd_stats_y is invalidated, which callers can
// use to prune the remaining motion modes. mbmi->ref_frame[1] is written back
// to its entry value on the early exits taken after the luma search.
static int txfm_search(const AV1_COMP *cpi, const TileDataEnc *tile_data,
                       MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row, int mi_col,
                       RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                       RD_STATS *rd_stats_uv, int mode_rate,
                       int64_t ref_best_rd) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int saved_ref_frame_1 = mbmi->ref_frame[1];
  const int ctx = av1_get_skip_context(xd);
  const int cost_no_skip = x->skip_cost[ctx][0];
  const int cost_skip = x->skip_cost[ctx][1];
  (void)tile_data;

  // Exactly one skip flag is charged in the end, so the mode rate plus the
  // cheaper flag is a lower bound on the header cost.
  const int64_t header_rate_floor =
      mode_rate + AOMMIN(cost_no_skip, cost_skip);
  if (RDCOST(x->rdmult, header_rate_floor, 0) > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  // RD budget left for the luma residual once the mode cost is paid.
  const int64_t mode_only_rd = RDCOST(x->rdmult, mode_rate, 0);
  const int64_t luma_rd_budget =
      (ref_best_rd == INT64_MAX) ? INT64_MAX : ref_best_rd - mode_only_rd;

  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // Luma residual and its transform search.
  av1_subtract_plane(x, bsize, 0);
  if (cm->tx_mode != TX_MODE_SELECT || xd->lossless[mbmi->segment_id]) {
    // Fixed transform size: one size for the whole block, and every 4x4 unit
    // inherits the block-level skip result.
    super_block_yrd(cpi, x, rd_stats_y, bsize, luma_rd_budget);
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    const int num_blks = xd->n4_h * xd->n4_w;
    for (int blk = 0; blk < num_blks; ++blk)
      set_blk_skip(x, 0, blk, rd_stats_y->skip);
  } else {
    pick_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, mi_row, mi_col,
                          luma_rd_budget);
#if CONFIG_COLLECT_RD_STATS == 2
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  }

  if (rd_stats_y->rate == INT_MAX) {
    // TODO(angiebird): check whether the destination buffers need restoring
    // (restore_dst_buf) on this path.
    mbmi->ref_frame[1] = saved_ref_frame_1;
    return 0;
  }

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Best achievable cost with luma alone: code the residual or signal skip.
  const int64_t coded_rd =
      RDCOST(x->rdmult, rd_stats->rate + cost_no_skip, rd_stats->dist);
  const int64_t skipped_rd =
      RDCOST(x->rdmult, mode_rate + cost_skip, rd_stats->sse);
  const int64_t luma_best_rd = AOMMIN(coded_rd, skipped_rd);
  if (luma_best_rd > ref_best_rd) {
    const int64_t token_rd_y =
        AOMMIN(RDCOST(x->rdmult, rd_stats_y->rate, rd_stats_y->dist),
               RDCOST(x->rdmult, 0, rd_stats_y->sse));
    // If even the luma token cost (discounted by prune_motion_mode_level)
    // exceeds the budget, invalidate rd_stats_y so the caller skips the rest
    // of the motion-mode search.
    if (token_rd_y - (token_rd_y >> cpi->sf.prune_motion_mode_level) >
        luma_rd_budget)
      av1_invalid_rd_stats(rd_stats_y);
    mbmi->ref_frame[1] = saved_ref_frame_1;
    return 0;
  }

  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    // Optionally tighten the chroma budget by the best luma cost so far.
    const int64_t chroma_budget =
        (cpi->sf.perform_best_rd_based_gating_for_chroma &&
         ref_best_rd != INT64_MAX)
            ? ref_best_rd - luma_best_rd
            : ref_best_rd;
    if (!super_block_uvrd(cpi, x, rd_stats_uv, bsize, chroma_budget)) {
      mbmi->ref_frame[1] = saved_ref_frame_1;
      return 0;
    }
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  if (rd_stats->skip) {
    // Replace the token rates with the skip flag and use SSE as distortion.
    rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    rd_stats->rate += cost_skip;
    // mbmi->skip temporarily plays the role of this_skip2 here.
    mbmi->skip = 1;
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) > ref_best_rd) {
      mbmi->ref_frame[1] = saved_ref_frame_1;
      return 0;
    }
    return 1;
  }

  // Outside lossless mode, force skip whenever it is at least as cheap as
  // coding the residual.
  const int skip_is_cheaper =
      !xd->lossless[mbmi->segment_id] &&
      RDCOST(x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + cost_no_skip,
             rd_stats->dist) >= RDCOST(x->rdmult, cost_skip, rd_stats->sse);
  if (!skip_is_cheaper) {
    rd_stats->rate += cost_no_skip;
    mbmi->skip = 0;
    return 1;
  }

  rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
  rd_stats->rate += cost_skip;
  rd_stats->dist = rd_stats->sse;
  rd_stats_y->dist = rd_stats_y->sse;
  rd_stats_uv->dist = rd_stats_uv->sse;
  rd_stats_y->rate = 0;
  rd_stats_uv->rate = 0;
  mbmi->skip = 1;
  return 1;
}
9009 
enable_wedge_search(MACROBLOCK * const x,const AV1_COMP * const cpi)9010 static INLINE bool enable_wedge_search(MACROBLOCK *const x,
9011                                        const AV1_COMP *const cpi) {
9012   // Enable wedge search if source variance and edge strength are above
9013   // the thresholds.
9014   return x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
9015          x->edge_strength > cpi->sf.disable_wedge_search_edge_thresh;
9016 }
9017 
enable_wedge_interinter_search(MACROBLOCK * const x,const AV1_COMP * const cpi)9018 static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
9019                                                   const AV1_COMP *const cpi) {
9020   return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interinter_wedge;
9021 }
9022 
enable_wedge_interintra_search(MACROBLOCK * const x,const AV1_COMP * const cpi)9023 static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
9024                                                   const AV1_COMP *const cpi) {
9025   return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interintra_wedge &&
9026          !cpi->sf.disable_wedge_interintra_search;
9027 }
9028 
// Evaluates inter-intra prediction for the current single-reference block.
//
// The luma inter prediction is first built into a local buffer (tmp_buf). It
// is then combined with intra predictions written into the destination buffer
// (xd->plane[0].dst, restored from orig_dst).
//  - Smooth inter-intra (when enabled): choose the INTERINTRA_MODE with the
//    best model RD, or reuse the mode cached in args->inter_intra_mode when
//    sf.reuse_inter_intra_mode allows it. Then re-score it with
//    estimate_yrd_for_sb().
//  - Wedge inter-intra (when is_interintra_wedge_used(bsize) and
//    enable_wedge_interintra_search() allow it): pick a wedge. For NEWMV
//    modes, optionally refine mv[0] against the wedge mask. Keep the wedge
//    only if it beats the non-wedge result.
//
// On success, mbmi->interintra_mode / use_wedge_interintra (and possibly
// mv[0]) describe the chosen configuration. *rate_mv and *tmp_rate2 are
// updated when a refined MV is kept. Chroma is rebuilt when num_planes > 1.
//
// Returns 0 on success, or -1 when inter-intra should be rejected for this
// block (estimated cost too high, or no usable candidate).
static int handle_inter_intra_mode(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, BLOCK_SIZE bsize,
                                   int mi_row, int mi_col, MB_MODE_INFO *mbmi,
                                   HandleInterModeArgs *args,
                                   int64_t ref_best_rd, int *rate_mv,
                                   int *tmp_rate2, const BUFFER_SET *orig_dst) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *xd = &x->e_mbd;

  INTERINTRA_MODE best_interintra_mode = II_DC_PRED;
  int64_t rd = INT64_MAX;
  int64_t best_interintra_rd = INT64_MAX;
  int rmode, rate_sum;
  int64_t dist_sum;
  int tmp_rate_mv = 0;
  int tmp_skip_txfm_sb;
  int bw = block_size_wide[bsize];
  int64_t tmp_skip_sse_sb;
  DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
  uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  const int *const interintra_mode_cost =
      x->interintra_mode_cost[size_group_lookup[bsize]];
  const int_mv mv0 = mbmi->mv[0];
  const int is_wedge_used = is_interintra_wedge_used(bsize);
  int rwedge = is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0;

  // Build the plain single-reference luma prediction into tmp_buf. It is the
  // inter half of every inter-intra combination evaluated below.
  mbmi->ref_frame[1] = NONE_FRAME;
  xd->plane[0].dst.buf = tmp_buf;
  xd->plane[0].dst.stride = bw;
  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
                                AOM_PLANE_Y, AOM_PLANE_Y);

  restore_dst_buf(xd, *orig_dst, num_planes);
  // From here on the block is evaluated as inter-intra.
  mbmi->ref_frame[1] = INTRA_FRAME;
  // A cached value of INTERINTRA_MODES is treated below as "no mode cached".
  best_interintra_mode = args->inter_intra_mode[mbmi->ref_frame[0]];

  if (cpi->oxcf.enable_smooth_interintra &&
      !cpi->sf.disable_smooth_interintra) {
    mbmi->use_wedge_interintra = 0;
    int j = 0;
    if (cpi->sf.reuse_inter_intra_mode == 0 ||
        best_interintra_mode == INTERINTRA_MODES) {
      // Pick the intra mode whose smooth combination has the lowest model RD.
      for (j = 0; j < INTERINTRA_MODES; ++j) {
        if ((!cpi->oxcf.enable_smooth_intra || cpi->sf.disable_smooth_intra) &&
            (INTERINTRA_MODE)j == II_SMOOTH_PRED)
          continue;
        mbmi->interintra_mode = (INTERINTRA_MODE)j;
        rmode = interintra_mode_cost[mbmi->interintra_mode];
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
        av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
        model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](
            cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
            &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
        rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
        if (rd < best_interintra_rd) {
          best_interintra_rd = rd;
          best_interintra_mode = mbmi->interintra_mode;
        }
      }
      args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
    }
    // The loop above never selects II_SMOOTH_PRED while smooth intra is
    // disabled. (The previous assert keyed on the smooth *inter-intra* flags,
    // which are always enabled inside this branch, so it checked nothing.)
    assert(IMPLIES(!cpi->oxcf.enable_smooth_intra ||
                       cpi->sf.disable_smooth_intra,
                   best_interintra_mode != II_SMOOTH_PRED));
    rmode = interintra_mode_cost[best_interintra_mode];
    // If the loop ran and its last visited mode (II_SMOOTH_PRED) won, the
    // combined prediction is already in place. Otherwise rebuild it.
    if (j == 0 || best_interintra_mode != II_SMOOTH_PRED) {
      mbmi->interintra_mode = best_interintra_mode;
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, bw);
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
    }

    RD_STATS rd_stats;
    rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
    if (rd != INT64_MAX) {
      rd = RDCOST(x->rdmult, *rate_mv + rmode + rd_stats.rate + rwedge,
                  rd_stats.dist);
    }
    best_interintra_rd = rd;
    // Reject when the estimate exceeds roughly 16/9 of the best RD so far.
    if (ref_best_rd < INT64_MAX &&
        ((best_interintra_rd >> 4) * 9) > ref_best_rd) {
      return -1;
    }
  }
  if (is_wedge_used) {
    int64_t best_interintra_rd_nowedge = rd;
    int64_t best_interintra_rd_wedge = INT64_MAX;
    int_mv tmp_mv;
    if (enable_wedge_interintra_search(x, cpi)) {
      mbmi->use_wedge_interintra = 1;

      rwedge = av1_cost_literal(get_interintra_wedge_bits(bsize)) +
               x->wedge_interintra_cost[bsize][1];

      if (!cpi->oxcf.enable_smooth_interintra ||
          cpi->sf.disable_smooth_interintra) {
        if (best_interintra_mode == INTERINTRA_MODES) {
          // No cached mode: pick the wedge using the smooth intra predictor,
          // then choose the intra mode to pair with that wedge.
          mbmi->interintra_mode = II_SMOOTH_PRED;
          best_interintra_mode = II_SMOOTH_PRED;
          av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                    intrapred, bw);
          best_interintra_rd_wedge =
              pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);

          // Track the running minimum locally. best_interintra_rd is still
          // INT64_MAX on this path, so comparing against it made every mode
          // "win" and the last mode was always selected.
          // NOTE(review): rd includes rmode, and rmode is added again to
          // best_interintra_rd_wedge after this block; confirm that is
          // intended.
          int64_t best_mode_rd = INT64_MAX;
          for (int j = 0; j < INTERINTRA_MODES; ++j) {
            mbmi->interintra_mode = (INTERINTRA_MODE)j;
            rmode = interintra_mode_cost[mbmi->interintra_mode];
            av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0,
                                                      orig_dst, intrapred, bw);
            av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
            model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](
                cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
            rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
            if (rd < best_mode_rd) {
              best_mode_rd = rd;
              best_interintra_rd_wedge = rd;
              best_interintra_mode = mbmi->interintra_mode;
            }
          }
          args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
          mbmi->interintra_mode = best_interintra_mode;

          // intrapred still holds the last visited mode (II_SMOOTH_PRED).
          if (best_interintra_mode != II_SMOOTH_PRED) {
            av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0,
                                                      orig_dst, intrapred, bw);
          }
        } else {
          mbmi->interintra_mode = best_interintra_mode;
          av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                    intrapred, bw);
          best_interintra_rd_wedge =
              pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
        }
      } else {
        best_interintra_rd_wedge =
            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
      }

      rmode = interintra_mode_cost[mbmi->interintra_mode];
      best_interintra_rd_wedge +=
          RDCOST(x->rdmult, rmode + *rate_mv + rwedge, 0);
      rd = INT64_MAX;
      // Refine motion vector.
      if (have_newmv_in_inter_mode(mbmi->mode)) {
        // get negative of mask
        const uint8_t *mask = av1_get_contiguous_soft_mask(
            mbmi->interintra_wedge_index, 1, bsize);
        tmp_mv = mbmi->mv[0];
        compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, mi_row,
                                      mi_col, intrapred, mask, bw, &tmp_rate_mv,
                                      0);
        if (mbmi->mv[0].as_int != tmp_mv.as_int) {
          mbmi->mv[0].as_int = tmp_mv.as_int;
          av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                        AOM_PLANE_Y, AOM_PLANE_Y);
          model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
              cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
              &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
          rd = RDCOST(x->rdmult, tmp_rate_mv + rmode + rate_sum + rwedge,
                      dist_sum);
        }
      }
      if (rd >= best_interintra_rd_wedge) {
        // Refinement did not help: fall back to the original MV and rebuild
        // the wedge combination from the unrefined inter prediction.
        tmp_mv.as_int = mv0.as_int;
        tmp_rate_mv = *rate_mv;
        av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
      }
      // Evaluate closer to true rd
      RD_STATS rd_stats;
      rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
      if (rd != INT64_MAX) {
        rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rd_stats.rate,
                    rd_stats.dist);
      }
      best_interintra_rd_wedge = rd;
      // Without smooth inter-intra the wedge is the only candidate.
      if ((!cpi->oxcf.enable_smooth_interintra ||
           cpi->sf.disable_smooth_interintra) &&
          best_interintra_rd_wedge == INT64_MAX)
        return -1;
      if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
        mbmi->use_wedge_interintra = 1;
        mbmi->mv[0].as_int = tmp_mv.as_int;
        *tmp_rate2 += tmp_rate_mv - *rate_mv;
        *rate_mv = tmp_rate_mv;
      } else {
        mbmi->use_wedge_interintra = 0;
        mbmi->mv[0].as_int = mv0.as_int;
        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                      AOM_PLANE_Y, AOM_PLANE_Y);
      }
    } else {
      // Neither wedge nor smooth inter-intra is available.
      if (!cpi->oxcf.enable_smooth_interintra ||
          cpi->sf.disable_smooth_interintra)
        return -1;
      mbmi->use_wedge_interintra = 0;
    }
  } else {
    if (best_interintra_rd == INT64_MAX) return -1;
  }
  if (num_planes > 1) {
    av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                  AOM_PLANE_U, num_planes - 1);
  }
  return 0;
}
9238 
9239 // If number of valid neighbours is 1,
9240 // 1) ROTZOOM parameters can be obtained reliably (2 parameters from
9241 // one neighbouring MV)
9242 // 2) For IDENTITY/TRANSLATION cases, warp can perform better due to
9243 // a different interpolation filter being used. However the quality
9244 // gains (due to the same) may not be much
9245 // For above 2 cases warp evaluation is skipped
9246 
check_if_optimal_warp(const AV1_COMP * cpi,WarpedMotionParams * wm_params,int num_proj_ref)9247 static int check_if_optimal_warp(const AV1_COMP *cpi,
9248                                  WarpedMotionParams *wm_params,
9249                                  int num_proj_ref) {
9250   int is_valid_warp = 1;
9251   if (cpi->sf.prune_warp_using_wmtype) {
9252     TransformationType wmtype = get_wmtype(wm_params);
9253     if (num_proj_ref == 1) {
9254       if (wmtype != ROTZOOM) is_valid_warp = 0;
9255     } else {
9256       if (wmtype < ROTZOOM) is_valid_warp = 0;
9257     }
9258   }
9259   return is_valid_warp;
9260 }
9261 
// Context passed (as void *) to obmc_check_identical_mv() while visiting the
// OBMC neighbours of a block.
struct obmc_check_mv_field_ctxt {
  // Block whose motion parameters the neighbours are compared against.
  MB_MODE_INFO *current_mi;
  // Starts at 1. obmc_check_identical_mv() clears it to 0 on the first
  // neighbour whose reference frame, MV or interpolation filters differ.
  int mv_field_check_result;
};
9266 
obmc_check_identical_mv(MACROBLOCKD * xd,int rel_mi_col,uint8_t nb_mi_width,MB_MODE_INFO * nb_mi,void * fun_ctxt,const int num_planes)9267 static INLINE void obmc_check_identical_mv(MACROBLOCKD *xd, int rel_mi_col,
9268                                            uint8_t nb_mi_width,
9269                                            MB_MODE_INFO *nb_mi, void *fun_ctxt,
9270                                            const int num_planes) {
9271   (void)xd;
9272   (void)rel_mi_col;
9273   (void)nb_mi_width;
9274   (void)num_planes;
9275   struct obmc_check_mv_field_ctxt *ctxt =
9276       (struct obmc_check_mv_field_ctxt *)fun_ctxt;
9277   const MB_MODE_INFO *current_mi = ctxt->current_mi;
9278 
9279   if (ctxt->mv_field_check_result == 0) return;
9280 
9281   if (nb_mi->ref_frame[0] != current_mi->ref_frame[0] ||
9282       nb_mi->mv[0].as_int != current_mi->mv[0].as_int ||
9283       nb_mi->interp_filters != current_mi->interp_filters) {
9284     ctxt->mv_field_check_result = 0;
9285   }
9286 }
9287 
9288 // Check if the neighbors' motions used by obmc have same parameters as for
9289 // the current block. If all the parameters are identical, obmc will produce
9290 // the same prediction as from regular bmc, therefore we can skip the
9291 // overlapping operations for less complexity. The parameters checked include
9292 // reference frame, motion vector, and interpolation filter.
check_identical_obmc_mv_field(const AV1_COMMON * cm,MACROBLOCKD * xd,int mi_row,int mi_col)9293 int check_identical_obmc_mv_field(const AV1_COMMON *cm, MACROBLOCKD *xd,
9294                                   int mi_row, int mi_col) {
9295   const BLOCK_SIZE bsize = xd->mi[0]->sb_type;
9296   struct obmc_check_mv_field_ctxt mv_field_check_ctxt = { xd->mi[0], 1 };
9297 
9298   foreach_overlappable_nb_above(cm, xd, mi_col,
9299                                 max_neighbor_obmc[mi_size_wide_log2[bsize]],
9300                                 obmc_check_identical_mv, &mv_field_check_ctxt);
9301   foreach_overlappable_nb_left(cm, xd, mi_row,
9302                                max_neighbor_obmc[mi_size_high_log2[bsize]],
9303                                obmc_check_identical_mv, &mv_field_check_ctxt);
9304 
9305   return mv_field_check_ctxt.mv_field_check_result;
9306 }
9307 
// Decides whether the inter-intra search can be skipped for this block, based
// on statistics gathered in the first pass of a two-pass partition search.
//
// Active only when cpi->two_pass_partition_search and the
// use_first_partition_pass_interintra_stats speed feature are both set and
// x->cb_partition_scan is not set. It then visits the first-pass stats
// sample regions covered by the block. It returns 0 (do not skip) as soon as
// one region recorded an inter-intra motion mode for mbmi->ref_frame[0], and
// 1 (skip) if none did. Returns 0 whenever the feature is inactive.
static int skip_interintra_based_on_first_pass_stats(const AV1_COMP *const cpi,
                                                     MACROBLOCK *const x,
                                                     BLOCK_SIZE bsize,
                                                     int mi_row, int mi_col) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  if (!cpi->two_pass_partition_search ||
      !cpi->sf.use_first_partition_pass_interintra_stats ||
      x->cb_partition_scan) {
    return 0;
  }

  const int mi_width = mi_size_wide[bsize];
  const int mi_height = mi_size_high[bsize];
  // Search the stats table to see whether an inter-intra motion mode was used
  // in the first pass of partition search. Rows must span the block height
  // and columns its width; the previous bounds were swapped, which scanned
  // the wrong regions for non-square blocks.
  for (int row = mi_row; row < mi_row + mi_height;
       row += FIRST_PARTITION_PASS_SAMPLE_REGION) {
    for (int col = mi_col; col < mi_col + mi_width;
         col += FIRST_PARTITION_PASS_SAMPLE_REGION) {
      const int index = av1_first_partition_pass_stats_index(row, col);
      const FIRST_PARTITION_PASS_STATS *const stats =
          &x->first_partition_pass_stats[index];
      if (stats->interintra_motion_mode_count[mbmi->ref_frame[0]]) {
        return 0;
      }
    }
  }
  return 1;
}
9337 
9338 // TODO(afergs): Refactor the MBMI references in here - there's four
9339 // TODO(afergs): Refactor optional args - add them to a struct or remove
motion_mode_rd(const AV1_COMP * const cpi,TileDataEnc * tile_data,MACROBLOCK * const x,BLOCK_SIZE bsize,RD_STATS * rd_stats,RD_STATS * rd_stats_y,RD_STATS * rd_stats_uv,int * disable_skip,int mi_row,int mi_col,HandleInterModeArgs * const args,int64_t ref_best_rd,const int * refs,int * rate_mv,const BUFFER_SET * orig_dst,int64_t * best_est_rd,int do_tx_search,InterModesInfo * inter_modes_info)9340 static int64_t motion_mode_rd(
9341     const AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *const x,
9342     BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
9343     RD_STATS *rd_stats_uv, int *disable_skip, int mi_row, int mi_col,
9344     HandleInterModeArgs *const args, int64_t ref_best_rd, const int *refs,
9345     int *rate_mv, const BUFFER_SET *orig_dst, int64_t *best_est_rd,
9346     int do_tx_search, InterModesInfo *inter_modes_info) {
9347   const AV1_COMMON *const cm = &cpi->common;
9348   const int num_planes = av1_num_planes(cm);
9349   MACROBLOCKD *xd = &x->e_mbd;
9350   MB_MODE_INFO *mbmi = xd->mi[0];
9351   const int is_comp_pred = has_second_ref(mbmi);
9352   const PREDICTION_MODE this_mode = mbmi->mode;
9353   const int rate2_nocoeff = rd_stats->rate;
9354   int best_xskip = 0, best_disable_skip = 0;
9355   RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
9356   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
9357   const int rate_mv0 = *rate_mv;
9358   int skip_interintra_mode = 0;
9359   const int interintra_allowed = cm->seq_params.enable_interintra_compound &&
9360                                  is_interintra_allowed(mbmi) &&
9361                                  mbmi->compound_idx;
9362   int pts0[SAMPLES_ARRAY_SIZE], pts_inref0[SAMPLES_ARRAY_SIZE];
9363 
9364   assert(mbmi->ref_frame[1] != INTRA_FRAME);
9365   const MV_REFERENCE_FRAME ref_frame_1 = mbmi->ref_frame[1];
9366   (void)tile_data;
9367   av1_invalid_rd_stats(&best_rd_stats);
9368   aom_clear_system_state();
9369   mbmi->num_proj_ref = 1;  // assume num_proj_ref >=1
9370   MOTION_MODE last_motion_mode_allowed = SIMPLE_TRANSLATION;
9371   if (cm->switchable_motion_mode) {
9372     last_motion_mode_allowed = motion_mode_allowed(xd->global_motion, xd, mbmi,
9373                                                    cm->allow_warped_motion);
9374   }
9375   if (last_motion_mode_allowed == WARPED_CAUSAL) {
9376     mbmi->num_proj_ref = findSamples(cm, xd, mi_row, mi_col, pts0, pts_inref0);
9377   }
9378   const int total_samples = mbmi->num_proj_ref;
9379   if (total_samples == 0) {
9380     last_motion_mode_allowed = OBMC_CAUSAL;
9381   }
9382 
9383   const MB_MODE_INFO base_mbmi = *mbmi;
9384   MB_MODE_INFO best_mbmi;
9385   SimpleRDState *const simple_states = &args->simple_rd_state[mbmi->ref_mv_idx];
9386   const int switchable_rate =
9387       av1_is_interp_needed(xd) ? av1_get_switchable_rate(cm, x, xd) : 0;
9388   int64_t best_rd = INT64_MAX;
9389   int best_rate_mv = rate_mv0;
9390   const int identical_obmc_mv_field_detected =
9391       (cpi->sf.skip_obmc_in_uniform_mv_field ||
9392        cpi->sf.skip_wm_in_uniform_mv_field)
9393           ? check_identical_obmc_mv_field(cm, xd, mi_row, mi_col)
9394           : 0;
9395   for (int mode_index = (int)SIMPLE_TRANSLATION;
9396        mode_index <= (int)last_motion_mode_allowed + interintra_allowed;
9397        mode_index++) {
9398     if (args->skip_motion_mode && mode_index) continue;
9399     if (cpi->sf.prune_single_motion_modes_by_simple_trans &&
9400         args->single_ref_first_pass && mode_index)
9401       break;
9402     int tmp_rate2 = rate2_nocoeff;
9403     const int is_interintra_mode = mode_index > (int)last_motion_mode_allowed;
9404     int tmp_rate_mv = rate_mv0;
9405 
9406     *mbmi = base_mbmi;
9407     if (is_interintra_mode) {
9408       mbmi->motion_mode = SIMPLE_TRANSLATION;
9409     } else {
9410       mbmi->motion_mode = (MOTION_MODE)mode_index;
9411       assert(mbmi->ref_frame[1] != INTRA_FRAME);
9412     }
9413 
9414     if (cpi->oxcf.enable_obmc == 0 && mbmi->motion_mode == OBMC_CAUSAL)
9415       continue;
9416 
9417     if (identical_obmc_mv_field_detected) {
9418       if (cpi->sf.skip_obmc_in_uniform_mv_field &&
9419           mbmi->motion_mode == OBMC_CAUSAL)
9420         continue;
9421       if (cpi->sf.skip_wm_in_uniform_mv_field &&
9422           mbmi->motion_mode == WARPED_CAUSAL)
9423         continue;
9424     }
9425 
9426     if (mbmi->motion_mode == SIMPLE_TRANSLATION && !is_interintra_mode) {
9427       // SIMPLE_TRANSLATION mode: no need to recalculate.
9428       // The prediction is calculated before motion_mode_rd() is called in
9429       // handle_inter_mode()
9430       if (cpi->sf.prune_single_motion_modes_by_simple_trans && !is_comp_pred) {
9431         if (args->single_ref_first_pass == 0) {
9432           if (simple_states->early_skipped) {
9433             assert(simple_states->rd_stats.rdcost == INT64_MAX);
9434             return INT64_MAX;
9435           }
9436           if (simple_states->rd_stats.rdcost != INT64_MAX) {
9437             best_rd = simple_states->rd_stats.rdcost;
9438             best_rd_stats = simple_states->rd_stats;
9439             best_rd_stats_y = simple_states->rd_stats_y;
9440             best_rd_stats_uv = simple_states->rd_stats_uv;
9441             memcpy(best_blk_skip, simple_states->blk_skip,
9442                    sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9443             best_xskip = simple_states->skip;
9444             best_disable_skip = simple_states->disable_skip;
9445             best_mbmi = *mbmi;
9446           }
9447           continue;
9448         }
9449         simple_states->early_skipped = 0;
9450       }
9451     } else if (mbmi->motion_mode == OBMC_CAUSAL) {
9452       const uint32_t cur_mv = mbmi->mv[0].as_int;
9453       assert(!is_comp_pred);
9454       if (have_newmv_in_inter_mode(this_mode)) {
9455         single_motion_search(cpi, x, bsize, mi_row, mi_col, 0, &tmp_rate_mv);
9456         mbmi->mv[0].as_int = x->best_mv.as_int;
9457 #if USE_DISCOUNT_NEWMV_TEST
9458         if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
9459           tmp_rate_mv = AOMMAX((tmp_rate_mv / NEW_MV_DISCOUNT_FACTOR), 1);
9460         }
9461 #endif
9462         tmp_rate2 = rate2_nocoeff - rate_mv0 + tmp_rate_mv;
9463       }
9464       if (mbmi->mv[0].as_int != cur_mv) {
9465         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
9466                                       0, av1_num_planes(cm) - 1);
9467       }
9468       av1_build_obmc_inter_prediction(
9469           cm, xd, mi_row, mi_col, args->above_pred_buf, args->above_pred_stride,
9470           args->left_pred_buf, args->left_pred_stride);
9471     } else if (mbmi->motion_mode == WARPED_CAUSAL) {
9472       int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
9473       mbmi->motion_mode = WARPED_CAUSAL;
9474       mbmi->wm_params.wmtype = DEFAULT_WMTYPE;
9475       mbmi->interp_filters = av1_broadcast_interp_filter(
9476           av1_unswitchable_filter(cm->interp_filter));
9477 
9478       memcpy(pts, pts0, total_samples * 2 * sizeof(*pts0));
9479       memcpy(pts_inref, pts_inref0, total_samples * 2 * sizeof(*pts_inref0));
9480       // Select the samples according to motion vector difference
9481       if (mbmi->num_proj_ref > 1) {
9482         mbmi->num_proj_ref = selectSamples(&mbmi->mv[0].as_mv, pts, pts_inref,
9483                                            mbmi->num_proj_ref, bsize);
9484       }
9485 
9486       if (!find_projection(mbmi->num_proj_ref, pts, pts_inref, bsize,
9487                            mbmi->mv[0].as_mv.row, mbmi->mv[0].as_mv.col,
9488                            &mbmi->wm_params, mi_row, mi_col)) {
9489         // Refine MV for NEWMV mode
9490         assert(!is_comp_pred);
9491         if (have_newmv_in_inter_mode(this_mode)) {
9492           const int_mv mv0 = mbmi->mv[0];
9493           const WarpedMotionParams wm_params0 = mbmi->wm_params;
9494           const int num_proj_ref0 = mbmi->num_proj_ref;
9495 
9496           if (cpi->sf.prune_warp_using_wmtype) {
9497             TransformationType wmtype = get_wmtype(&mbmi->wm_params);
9498             if (wmtype < ROTZOOM) continue;
9499           }
9500 
9501           // Refine MV in a small range.
9502           av1_refine_warped_mv(cpi, x, bsize, mi_row, mi_col, pts0, pts_inref0,
9503                                total_samples);
9504 
9505           // Keep the refined MV and WM parameters.
9506           if (mv0.as_int != mbmi->mv[0].as_int) {
9507             const int ref = refs[0];
9508             const int_mv ref_mv = av1_get_ref_mv(x, 0);
9509             tmp_rate_mv = av1_mv_bit_cost(&mbmi->mv[0].as_mv, &ref_mv.as_mv,
9510                                           x->nmv_vec_cost, x->mv_cost_stack,
9511                                           MV_COST_WEIGHT);
9512 
9513             if (cpi->sf.adaptive_motion_search)
9514               x->pred_mv[ref] = mbmi->mv[0].as_mv;
9515 
9516 #if USE_DISCOUNT_NEWMV_TEST
9517             if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
9518               tmp_rate_mv = AOMMAX((tmp_rate_mv / NEW_MV_DISCOUNT_FACTOR), 1);
9519             }
9520 #endif
9521             tmp_rate2 = rate2_nocoeff - rate_mv0 + tmp_rate_mv;
9522           } else {
9523             // Restore the old MV and WM parameters.
9524             mbmi->mv[0] = mv0;
9525             mbmi->wm_params = wm_params0;
9526             mbmi->num_proj_ref = num_proj_ref0;
9527           }
9528         } else {
9529           if (!check_if_optimal_warp(cpi, &mbmi->wm_params, mbmi->num_proj_ref))
9530             continue;
9531         }
9532 
9533         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
9534                                       av1_num_planes(cm) - 1);
9535       } else {
9536         continue;
9537       }
9538     } else if (is_interintra_mode) {
9539       skip_interintra_mode = skip_interintra_based_on_first_pass_stats(
9540           cpi, x, bsize, mi_row, mi_col);
9541       if (skip_interintra_mode) continue;
9542       const int ret = handle_inter_intra_mode(
9543           cpi, x, bsize, mi_row, mi_col, mbmi, args, ref_best_rd, &tmp_rate_mv,
9544           &tmp_rate2, orig_dst);
9545       if (ret < 0) continue;
9546     }
9547 
9548     x->skip = 0;
9549     rd_stats->dist = 0;
9550     rd_stats->sse = 0;
9551     rd_stats->skip = 1;
9552     rd_stats->rate = tmp_rate2;
9553     if (mbmi->motion_mode != WARPED_CAUSAL) rd_stats->rate += switchable_rate;
9554     if (interintra_allowed) {
9555       rd_stats->rate += x->interintra_cost[size_group_lookup[bsize]]
9556                                           [mbmi->ref_frame[1] == INTRA_FRAME];
9557       if (mbmi->ref_frame[1] == INTRA_FRAME) {
9558         rd_stats->rate += x->interintra_mode_cost[size_group_lookup[bsize]]
9559                                                  [mbmi->interintra_mode];
9560         if (is_interintra_wedge_used(bsize)) {
9561           rd_stats->rate +=
9562               x->wedge_interintra_cost[bsize][mbmi->use_wedge_interintra];
9563           if (mbmi->use_wedge_interintra) {
9564             rd_stats->rate +=
9565                 av1_cost_literal(get_interintra_wedge_bits(bsize));
9566           }
9567         }
9568       }
9569     }
9570     if ((last_motion_mode_allowed > SIMPLE_TRANSLATION) &&
9571         (mbmi->ref_frame[1] != INTRA_FRAME)) {
9572       if (last_motion_mode_allowed == WARPED_CAUSAL) {
9573         rd_stats->rate += x->motion_mode_cost[bsize][mbmi->motion_mode];
9574       } else {
9575         rd_stats->rate += x->motion_mode_cost1[bsize][mbmi->motion_mode];
9576       }
9577     }
9578 
9579     if (cpi->sf.model_based_motion_mode_rd_breakout && do_tx_search) {
9580       int model_rate;
9581       int64_t model_dist;
9582       model_rd_sb_fn[MODELRD_TYPE_MOTION_MODE_RD](
9583           cpi, mbmi->sb_type, x, xd, 0, num_planes - 1, mi_row, mi_col,
9584           &model_rate, &model_dist, NULL, NULL, NULL, NULL, NULL);
9585       const int64_t est_rd =
9586           RDCOST(x->rdmult, rd_stats->rate + model_rate, model_dist);
9587       if ((est_rd >> 3) * 6 > ref_best_rd) {
9588         mbmi->ref_frame[1] = ref_frame_1;
9589         continue;
9590       }
9591     }
9592 
9593     if (!do_tx_search) {
9594       int64_t curr_sse = -1;
9595       int est_residue_cost = 0;
9596       int64_t est_dist = 0;
9597       int64_t est_rd = 0;
9598       if (cpi->sf.inter_mode_rd_model_estimation == 1) {
9599         curr_sse = get_sse(cpi, x);
9600         const int has_est_rd = get_est_rate_dist(tile_data, bsize, curr_sse,
9601                                                  &est_residue_cost, &est_dist);
9602         (void)has_est_rd;
9603         assert(has_est_rd);
9604       } else if (cpi->sf.inter_mode_rd_model_estimation == 2 ||
9605                  cpi->sf.use_nonrd_pick_mode) {
9606         model_rd_sb_fn[MODELRD_TYPE_MOTION_MODE_RD](
9607             cpi, bsize, x, xd, 0, num_planes - 1, mi_row, mi_col,
9608             &est_residue_cost, &est_dist, NULL, &curr_sse, NULL, NULL, NULL);
9609       }
9610       est_rd = RDCOST(x->rdmult, rd_stats->rate + est_residue_cost, est_dist);
9611       if (est_rd * 0.8 > *best_est_rd) {
9612         mbmi->ref_frame[1] = ref_frame_1;
9613         continue;
9614       }
9615       const int mode_rate = rd_stats->rate;
9616       rd_stats->rate += est_residue_cost;
9617       rd_stats->dist = est_dist;
9618       rd_stats->rdcost = est_rd;
9619       *best_est_rd = AOMMIN(*best_est_rd, rd_stats->rdcost);
9620       if (cm->current_frame.reference_mode == SINGLE_REFERENCE) {
9621         if (!is_comp_pred) {
9622           assert(curr_sse >= 0);
9623           inter_modes_info_push(inter_modes_info, mode_rate, curr_sse,
9624                                 rd_stats->rdcost, false, NULL, rd_stats,
9625                                 rd_stats_y, rd_stats_uv, mbmi);
9626         }
9627       } else {
9628         assert(curr_sse >= 0);
9629         inter_modes_info_push(inter_modes_info, mode_rate, curr_sse,
9630                               rd_stats->rdcost, false, NULL, rd_stats,
9631                               rd_stats_y, rd_stats_uv, mbmi);
9632       }
9633     } else {
9634       if (!txfm_search(cpi, tile_data, x, bsize, mi_row, mi_col, rd_stats,
9635                        rd_stats_y, rd_stats_uv, rd_stats->rate, ref_best_rd)) {
9636         if (rd_stats_y->rate == INT_MAX && mode_index == 0) {
9637           if (cpi->sf.prune_single_motion_modes_by_simple_trans &&
9638               !is_comp_pred) {
9639             simple_states->early_skipped = 1;
9640           }
9641           return INT64_MAX;
9642         }
9643         continue;
9644       }
9645 
9646       const int64_t curr_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
9647       ref_best_rd = AOMMIN(ref_best_rd, curr_rd);
9648       *disable_skip = 0;
9649       if (cpi->sf.inter_mode_rd_model_estimation == 1) {
9650         const int skip_ctx = av1_get_skip_context(xd);
9651         inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats->sse,
9652                              rd_stats->dist,
9653                              rd_stats_y->rate + rd_stats_uv->rate +
9654                                  x->skip_cost[skip_ctx][mbmi->skip]);
9655       }
9656 
9657       // 2 means to both do the tx search and also update the inter_modes_info
9658       // structure, since some modes will be conditionally TX searched.
9659       if (do_tx_search == 2) {
9660         rd_stats->rdcost = curr_rd;
9661         inter_modes_info_push(inter_modes_info, rd_stats->rate, rd_stats->sse,
9662                               curr_rd, true, x->blk_skip, rd_stats, rd_stats_y,
9663                               rd_stats_uv, mbmi);
9664       }
9665     }
9666 
9667     if (this_mode == GLOBALMV || this_mode == GLOBAL_GLOBALMV) {
9668       if (is_nontrans_global_motion(xd, xd->mi[0])) {
9669         mbmi->interp_filters = av1_broadcast_interp_filter(
9670             av1_unswitchable_filter(cm->interp_filter));
9671       }
9672     }
9673 
9674     const int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
9675     if (mode_index == 0) {
9676       args->simple_rd[this_mode][mbmi->ref_mv_idx][mbmi->ref_frame[0]] = tmp_rd;
9677       if (!is_comp_pred) {
9678         simple_states->rd_stats = *rd_stats;
9679         simple_states->rd_stats.rdcost = tmp_rd;
9680         simple_states->rd_stats_y = *rd_stats_y;
9681         simple_states->rd_stats_uv = *rd_stats_uv;
9682         memcpy(simple_states->blk_skip, x->blk_skip,
9683                sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9684         simple_states->skip = x->skip;
9685         simple_states->disable_skip = *disable_skip;
9686       }
9687     }
9688     if (mode_index == 0 || tmp_rd < best_rd) {
9689       best_mbmi = *mbmi;
9690       best_rd = tmp_rd;
9691       best_rd_stats = *rd_stats;
9692       best_rd_stats_y = *rd_stats_y;
9693       best_rate_mv = tmp_rate_mv;
9694       if (num_planes > 1) best_rd_stats_uv = *rd_stats_uv;
9695       memcpy(best_blk_skip, x->blk_skip,
9696              sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9697       best_xskip = x->skip;
9698       best_disable_skip = *disable_skip;
9699       if (best_xskip) break;
9700     }
9701   }
9702   mbmi->ref_frame[1] = ref_frame_1;
9703   *rate_mv = best_rate_mv;
9704   if (best_rd == INT64_MAX) {
9705     av1_invalid_rd_stats(rd_stats);
9706     restore_dst_buf(xd, *orig_dst, num_planes);
9707     return INT64_MAX;
9708   }
9709   *mbmi = best_mbmi;
9710   *rd_stats = best_rd_stats;
9711   *rd_stats_y = best_rd_stats_y;
9712   if (num_planes > 1) *rd_stats_uv = best_rd_stats_uv;
9713   memcpy(x->blk_skip, best_blk_skip,
9714          sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9715   x->skip = best_xskip;
9716   *disable_skip = best_disable_skip;
9717 
9718   restore_dst_buf(xd, *orig_dst, num_planes);
9719   return 0;
9720 }
9721 
skip_mode_rd(RD_STATS * rd_stats,const AV1_COMP * const cpi,MACROBLOCK * const x,BLOCK_SIZE bsize,int mi_row,int mi_col,const BUFFER_SET * const orig_dst)9722 static int64_t skip_mode_rd(RD_STATS *rd_stats, const AV1_COMP *const cpi,
9723                             MACROBLOCK *const x, BLOCK_SIZE bsize, int mi_row,
9724                             int mi_col, const BUFFER_SET *const orig_dst) {
9725   const AV1_COMMON *cm = &cpi->common;
9726   const int num_planes = av1_num_planes(cm);
9727   MACROBLOCKD *const xd = &x->e_mbd;
9728   av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize, 0,
9729                                 av1_num_planes(cm) - 1);
9730 
9731   int64_t total_sse = 0;
9732   for (int plane = 0; plane < num_planes; ++plane) {
9733     const struct macroblock_plane *const p = &x->plane[plane];
9734     const struct macroblockd_plane *const pd = &xd->plane[plane];
9735     const BLOCK_SIZE plane_bsize =
9736         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
9737     const int bw = block_size_wide[plane_bsize];
9738     const int bh = block_size_high[plane_bsize];
9739 
9740     av1_subtract_plane(x, bsize, plane);
9741     int64_t sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh) << 4;
9742     total_sse += sse;
9743   }
9744   const int skip_mode_ctx = av1_get_skip_mode_context(xd);
9745   rd_stats->dist = rd_stats->sse = total_sse;
9746   rd_stats->rate = x->skip_mode_cost[skip_mode_ctx][1];
9747   rd_stats->rdcost = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
9748 
9749   restore_dst_buf(xd, *orig_dst, num_planes);
9750   return 0;
9751 }
9752 
get_ref_mv_offset(PREDICTION_MODE single_mode,uint8_t ref_mv_idx)9753 static INLINE int get_ref_mv_offset(PREDICTION_MODE single_mode,
9754                                     uint8_t ref_mv_idx) {
9755   assert(is_inter_singleref_mode(single_mode));
9756   int ref_mv_offset;
9757   if (single_mode == NEARESTMV) {
9758     ref_mv_offset = 0;
9759   } else if (single_mode == NEARMV) {
9760     ref_mv_offset = ref_mv_idx + 1;
9761   } else {
9762     ref_mv_offset = -1;
9763   }
9764   return ref_mv_offset;
9765 }
9766 
get_this_mv(int_mv * this_mv,PREDICTION_MODE this_mode,int ref_idx,int ref_mv_idx,const MV_REFERENCE_FRAME * ref_frame,const MB_MODE_INFO_EXT * mbmi_ext)9767 static INLINE void get_this_mv(int_mv *this_mv, PREDICTION_MODE this_mode,
9768                                int ref_idx, int ref_mv_idx,
9769                                const MV_REFERENCE_FRAME *ref_frame,
9770                                const MB_MODE_INFO_EXT *mbmi_ext) {
9771   const uint8_t ref_frame_type = av1_ref_frame_type(ref_frame);
9772   const int is_comp_pred = ref_frame[1] > INTRA_FRAME;
9773   const PREDICTION_MODE single_mode =
9774       get_single_mode(this_mode, ref_idx, is_comp_pred);
9775   assert(is_inter_singleref_mode(single_mode));
9776   if (single_mode == NEWMV) {
9777     this_mv->as_int = INVALID_MV;
9778   } else if (single_mode == GLOBALMV) {
9779     *this_mv = mbmi_ext->global_mvs[ref_frame[ref_idx]];
9780   } else {
9781     assert(single_mode == NEARMV || single_mode == NEARESTMV);
9782     const int ref_mv_offset = get_ref_mv_offset(single_mode, ref_mv_idx);
9783     if (ref_mv_offset < mbmi_ext->ref_mv_count[ref_frame_type]) {
9784       assert(ref_mv_offset >= 0);
9785       if (ref_idx == 0) {
9786         *this_mv =
9787             mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_offset].this_mv;
9788       } else {
9789         *this_mv =
9790             mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_offset].comp_mv;
9791       }
9792     } else {
9793       *this_mv = mbmi_ext->global_mvs[ref_frame[ref_idx]];
9794     }
9795   }
9796 }
9797 
9798 // This function update the non-new mv for the current prediction mode
build_cur_mv(int_mv * cur_mv,PREDICTION_MODE this_mode,const AV1_COMMON * cm,const MACROBLOCK * x)9799 static INLINE int build_cur_mv(int_mv *cur_mv, PREDICTION_MODE this_mode,
9800                                const AV1_COMMON *cm, const MACROBLOCK *x) {
9801   const MACROBLOCKD *xd = &x->e_mbd;
9802   const MB_MODE_INFO *mbmi = xd->mi[0];
9803   const int is_comp_pred = has_second_ref(mbmi);
9804   int ret = 1;
9805   for (int i = 0; i < is_comp_pred + 1; ++i) {
9806     int_mv this_mv;
9807     get_this_mv(&this_mv, this_mode, i, mbmi->ref_mv_idx, mbmi->ref_frame,
9808                 x->mbmi_ext);
9809     const PREDICTION_MODE single_mode =
9810         get_single_mode(this_mode, i, is_comp_pred);
9811     if (single_mode == NEWMV) {
9812       cur_mv[i] = this_mv;
9813     } else {
9814       ret &= clamp_and_check_mv(cur_mv + i, this_mv, cm, x);
9815     }
9816   }
9817   return ret;
9818 }
9819 
get_drl_cost(const MB_MODE_INFO * mbmi,const MB_MODE_INFO_EXT * mbmi_ext,int (* drl_mode_cost0)[2],int8_t ref_frame_type)9820 static INLINE int get_drl_cost(const MB_MODE_INFO *mbmi,
9821                                const MB_MODE_INFO_EXT *mbmi_ext,
9822                                int (*drl_mode_cost0)[2],
9823                                int8_t ref_frame_type) {
9824   int cost = 0;
9825   if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
9826     for (int idx = 0; idx < 2; ++idx) {
9827       if (mbmi_ext->ref_mv_count[ref_frame_type] > idx + 1) {
9828         uint8_t drl_ctx =
9829             av1_drl_ctx(mbmi_ext->ref_mv_stack[ref_frame_type], idx);
9830         cost += drl_mode_cost0[drl_ctx][mbmi->ref_mv_idx != idx];
9831         if (mbmi->ref_mv_idx == idx) return cost;
9832       }
9833     }
9834     return cost;
9835   }
9836 
9837   if (have_nearmv_in_inter_mode(mbmi->mode)) {
9838     for (int idx = 1; idx < 3; ++idx) {
9839       if (mbmi_ext->ref_mv_count[ref_frame_type] > idx + 1) {
9840         uint8_t drl_ctx =
9841             av1_drl_ctx(mbmi_ext->ref_mv_stack[ref_frame_type], idx);
9842         cost += drl_mode_cost0[drl_ctx][mbmi->ref_mv_idx != (idx - 1)];
9843         if (mbmi->ref_mv_idx == (idx - 1)) return cost;
9844       }
9845     }
9846     return cost;
9847   }
9848   return cost;
9849 }
9850 
// Scratch buffers used by compound_type_rd() while searching the masked
// compound types.
// Sizes and alignment of these arrays are set by
// alloc_compound_type_rd_buffers().
typedef struct {
  // Prediction buffers handed to build_and_cost_compound_type() as preds0 /
  // preds1 (presumably the per-reference predictions; TODO confirm).
  uint8_t *pred0;
  uint8_t *pred1;
  int16_t *residual1;          // src - pred1
  int16_t *diff10;             // pred1 - pred0
  uint8_t *tmp_best_mask_buf;  // backup of the best segmentation mask
} CompoundTypeRdBuffers;
9861 
compound_type_rd(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int mi_col,int mi_row,int_mv * cur_mv,int mode_search_mask,int masked_compound_used,const BUFFER_SET * orig_dst,const BUFFER_SET * tmp_dst,CompoundTypeRdBuffers * buffers,int * rate_mv,int64_t * rd,RD_STATS * rd_stats,int64_t ref_best_rd,int * is_luma_interp_done)9862 static int compound_type_rd(
9863     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
9864     int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
9865     const BUFFER_SET *orig_dst, const BUFFER_SET *tmp_dst,
9866     CompoundTypeRdBuffers *buffers, int *rate_mv, int64_t *rd,
9867     RD_STATS *rd_stats, int64_t ref_best_rd, int *is_luma_interp_done) {
9868   const AV1_COMMON *cm = &cpi->common;
9869   MACROBLOCKD *xd = &x->e_mbd;
9870   MB_MODE_INFO *mbmi = xd->mi[0];
9871   const PREDICTION_MODE this_mode = mbmi->mode;
9872   const int bw = block_size_wide[bsize];
9873   int rs2;
9874   int_mv best_mv[2];
9875   int best_tmp_rate_mv = *rate_mv;
9876   INTERINTER_COMPOUND_DATA best_compound_data;
9877   best_compound_data.type = COMPOUND_AVERAGE;
9878   uint8_t *preds0[1] = { buffers->pred0 };
9879   uint8_t *preds1[1] = { buffers->pred1 };
9880   int strides[1] = { bw };
9881   int tmp_rate_mv;
9882   const int num_pix = 1 << num_pels_log2_lookup[bsize];
9883   const int mask_len = 2 * num_pix * sizeof(uint8_t);
9884   COMPOUND_TYPE cur_type;
9885   int best_compmode_interinter_cost = 0;
9886   int calc_pred_masked_compound = 1;
9887   int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
9888                                         INT64_MAX };
9889   int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
9890   int64_t comp_model_rd[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
9891                                             INT64_MAX };
9892   const int match_found =
9893       find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rd);
9894 
9895   best_mv[0].as_int = cur_mv[0].as_int;
9896   best_mv[1].as_int = cur_mv[1].as_int;
9897   *rd = INT64_MAX;
9898   int rate_sum, tmp_skip_txfm_sb;
9899   int64_t dist_sum, tmp_skip_sse_sb;
9900   int64_t comp_best_model_rd = INT64_MAX;
9901   // Special handling if both compound_average and compound_distwtd
9902   // are to be searched. In this case, first estimate between the two
9903   // modes and then call estimate_yrd_for_sb() only for the better of
9904   // the two.
9905   const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
9906   const int try_distwtd_comp =
9907       ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
9908        cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
9909        cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
9910   const int try_average_and_distwtd_comp =
9911       try_average_comp && try_distwtd_comp &&
9912       comp_rate[COMPOUND_AVERAGE] == INT_MAX &&
9913       comp_rate[COMPOUND_DISTWTD] == INT_MAX;
9914   for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
9915     if (((1 << cur_type) & mode_search_mask) == 0) {
9916       if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
9917       continue;
9918     }
9919     if (!is_interinter_compound_used(cur_type, bsize)) continue;
9920     if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
9921     if (cur_type == COMPOUND_DISTWTD && !try_distwtd_comp) continue;
9922     if (cur_type == COMPOUND_AVERAGE && try_average_and_distwtd_comp) continue;
9923 
9924     int64_t comp_model_rd_cur = INT64_MAX;
9925     tmp_rate_mv = *rate_mv;
9926     int64_t best_rd_cur = INT64_MAX;
9927     const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
9928     const int comp_index_ctx = get_comp_index_context(cm, xd);
9929 
9930     if (cur_type == COMPOUND_DISTWTD && try_average_and_distwtd_comp) {
9931       int est_rate[2];
9932       int64_t est_dist[2], est_rd[2];
9933 
9934       int masked_type_cost[2] = { 0, 0 };
9935       mbmi->comp_group_idx = 0;
9936 
9937       // First find the modeled rd cost for COMPOUND_AVERAGE
9938       mbmi->interinter_comp.type = COMPOUND_AVERAGE;
9939       mbmi->compound_idx = 1;
9940       if (masked_compound_used) {
9941         masked_type_cost[COMPOUND_AVERAGE] +=
9942             x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
9943       }
9944       masked_type_cost[COMPOUND_AVERAGE] +=
9945           x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
9946       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
9947                                     AOM_PLANE_Y, AOM_PLANE_Y);
9948       *is_luma_interp_done = 1;
9949       model_rd_sb_fn[MODELRD_CURVFIT](
9950           cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_AVERAGE],
9951           &est_dist[COMPOUND_AVERAGE], NULL, NULL, NULL, NULL, NULL);
9952       est_rate[COMPOUND_AVERAGE] += masked_type_cost[COMPOUND_AVERAGE];
9953       est_rd[COMPOUND_AVERAGE] =
9954           RDCOST(x->rdmult, est_rate[COMPOUND_AVERAGE] + *rate_mv,
9955                  est_dist[COMPOUND_AVERAGE]);
9956       restore_dst_buf(xd, *tmp_dst, 1);
9957 
9958       // Next find the modeled rd cost for COMPOUND_DISTWTD
9959       mbmi->interinter_comp.type = COMPOUND_DISTWTD;
9960       mbmi->compound_idx = 0;
9961       if (masked_compound_used) {
9962         masked_type_cost[COMPOUND_DISTWTD] +=
9963             x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
9964       }
9965       masked_type_cost[COMPOUND_DISTWTD] +=
9966           x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
9967       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
9968                                     AOM_PLANE_Y, AOM_PLANE_Y);
9969       model_rd_sb_fn[MODELRD_CURVFIT](
9970           cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_DISTWTD],
9971           &est_dist[COMPOUND_DISTWTD], NULL, NULL, NULL, NULL, NULL);
9972       est_rate[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_DISTWTD];
9973       est_rd[COMPOUND_DISTWTD] =
9974           RDCOST(x->rdmult, est_rate[COMPOUND_DISTWTD] + *rate_mv,
9975                  est_dist[COMPOUND_DISTWTD]);
9976 
9977       // Choose the better of the two based on modeled cost and call
9978       // estimate_yrd_for_sb() for that one.
9979       if (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD]) {
9980         mbmi->interinter_comp.type = COMPOUND_AVERAGE;
9981         mbmi->compound_idx = 1;
9982         restore_dst_buf(xd, *orig_dst, 1);
9983         RD_STATS est_rd_stats;
9984         const int64_t est_rd_ =
9985             estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
9986         rs2 = masked_type_cost[COMPOUND_AVERAGE];
9987         if (est_rd_ != INT64_MAX) {
9988           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
9989                                est_rd_stats.dist);
9990           restore_dst_buf(xd, *tmp_dst, 1);
9991           comp_rate[COMPOUND_AVERAGE] = est_rd_stats.rate;
9992           comp_dist[COMPOUND_AVERAGE] = est_rd_stats.dist;
9993           comp_model_rd[COMPOUND_AVERAGE] = est_rd[COMPOUND_AVERAGE];
9994           comp_model_rd_cur = est_rd[COMPOUND_AVERAGE];
9995         }
9996         restore_dst_buf(xd, *tmp_dst, 1);
9997       } else {
9998         RD_STATS est_rd_stats;
9999         const int64_t est_rd_ =
10000             estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
10001         rs2 = masked_type_cost[COMPOUND_DISTWTD];
10002         if (est_rd_ != INT64_MAX) {
10003           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
10004                                est_rd_stats.dist);
10005           comp_rate[COMPOUND_DISTWTD] = est_rd_stats.rate;
10006           comp_dist[COMPOUND_DISTWTD] = est_rd_stats.dist;
10007           comp_model_rd[COMPOUND_DISTWTD] = est_rd[COMPOUND_DISTWTD];
10008           comp_model_rd_cur = est_rd[COMPOUND_DISTWTD];
10009         }
10010       }
10011     } else {
10012       mbmi->interinter_comp.type = cur_type;
10013       int masked_type_cost = 0;
10014       if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
10015         mbmi->comp_group_idx = 0;
10016         mbmi->compound_idx = (cur_type == COMPOUND_AVERAGE);
10017         if (masked_compound_used) {
10018           masked_type_cost +=
10019               x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
10020         }
10021         masked_type_cost +=
10022             x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
10023         rs2 = masked_type_cost;
10024         const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
10025         if (mode_rd < ref_best_rd) {
10026           // Reuse data if matching record is found
10027           if (comp_rate[cur_type] == INT_MAX) {
10028             av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
10029                                           bsize, AOM_PLANE_Y, AOM_PLANE_Y);
10030             if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
10031             RD_STATS est_rd_stats;
10032             const int64_t est_rd =
10033                 estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
10034             if (comp_rate[cur_type] != INT_MAX) {
10035               assert(comp_rate[cur_type] == est_rd_stats.rate);
10036               assert(comp_dist[cur_type] == est_rd_stats.dist);
10037             }
10038             if (est_rd != INT64_MAX) {
10039               best_rd_cur =
10040                   RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
10041                          est_rd_stats.dist);
10042               model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
10043                   cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
10044                   &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
10045               comp_model_rd_cur =
10046                   RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
10047 
10048               // Backup rate and distortion for future reuse
10049               comp_rate[cur_type] = est_rd_stats.rate;
10050               comp_dist[cur_type] = est_rd_stats.dist;
10051               comp_model_rd[cur_type] = comp_model_rd_cur;
10052             }
10053           } else {
10054             // Calculate RD cost based on stored stats
10055             assert(comp_dist[cur_type] != INT64_MAX);
10056             best_rd_cur =
10057                 RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
10058                        comp_dist[cur_type]);
10059             comp_model_rd_cur = comp_model_rd[cur_type];
10060           }
10061         }
10062         // use spare buffer for following compound type try
10063         if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
10064       } else {
10065         mbmi->comp_group_idx = 1;
10066         mbmi->compound_idx = 1;
10067         masked_type_cost +=
10068             x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
10069         masked_type_cost +=
10070             x->compound_type_cost[bsize][cur_type - COMPOUND_WEDGE];
10071         rs2 = masked_type_cost;
10072 
10073         if (((*rd / cpi->max_comp_type_rd_threshold_div) *
10074              cpi->max_comp_type_rd_threshold_mul) < ref_best_rd) {
10075           const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
10076 
10077           if (!((compound_type == COMPOUND_WEDGE &&
10078                  !enable_wedge_interinter_search(x, cpi)) ||
10079                 (compound_type == COMPOUND_DIFFWTD &&
10080                  !cpi->oxcf.enable_diff_wtd_comp)))
10081             best_rd_cur = build_and_cost_compound_type(
10082                 cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
10083                 &tmp_rate_mv, preds0, preds1, buffers->residual1,
10084                 buffers->diff10, strides, mi_row, mi_col, rd_stats->rate,
10085                 ref_best_rd, &calc_pred_masked_compound, comp_rate, comp_dist,
10086                 comp_model_rd, comp_best_model_rd, &comp_model_rd_cur);
10087         }
10088       }
10089     }
10090     if (best_rd_cur < *rd) {
10091       *rd = best_rd_cur;
10092       comp_best_model_rd = comp_model_rd_cur;
10093       best_compound_data = mbmi->interinter_comp;
10094       if (masked_compound_used && cur_type >= COMPOUND_WEDGE) {
10095         memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
10096       }
10097       best_compmode_interinter_cost = rs2;
10098       if (have_newmv_in_inter_mode(this_mode)) {
10099         if (cur_type == COMPOUND_WEDGE) {
10100           best_tmp_rate_mv = tmp_rate_mv;
10101           best_mv[0].as_int = mbmi->mv[0].as_int;
10102           best_mv[1].as_int = mbmi->mv[1].as_int;
10103         } else {
10104           best_mv[0].as_int = cur_mv[0].as_int;
10105           best_mv[1].as_int = cur_mv[1].as_int;
10106         }
10107       }
10108     }
10109     // reset to original mvs for next iteration
10110     mbmi->mv[0].as_int = cur_mv[0].as_int;
10111     mbmi->mv[1].as_int = cur_mv[1].as_int;
10112   }
10113   if (mbmi->interinter_comp.type != best_compound_data.type) {
10114     mbmi->comp_group_idx = (best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
10115     mbmi->compound_idx = !(best_compound_data.type == COMPOUND_DISTWTD);
10116     mbmi->interinter_comp = best_compound_data;
10117     memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
10118   }
10119   if (have_newmv_in_inter_mode(this_mode)) {
10120     mbmi->mv[0].as_int = best_mv[0].as_int;
10121     mbmi->mv[1].as_int = best_mv[1].as_int;
10122     if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
10123       rd_stats->rate += best_tmp_rate_mv - *rate_mv;
10124       *rate_mv = best_tmp_rate_mv;
10125     }
10126   }
10127   restore_dst_buf(xd, *orig_dst, 1);
10128   if (!match_found)
10129     save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rd,
10130                              cur_mv);
10131   return best_compmode_interinter_cost;
10132 }
10133 
is_single_newmv_valid(HandleInterModeArgs * args,MB_MODE_INFO * mbmi,PREDICTION_MODE this_mode)10134 static INLINE int is_single_newmv_valid(HandleInterModeArgs *args,
10135                                         MB_MODE_INFO *mbmi,
10136                                         PREDICTION_MODE this_mode) {
10137   for (int ref_idx = 0; ref_idx < 2; ++ref_idx) {
10138     const PREDICTION_MODE single_mode = get_single_mode(this_mode, ref_idx, 1);
10139     const MV_REFERENCE_FRAME ref = mbmi->ref_frame[ref_idx];
10140     if (single_mode == NEWMV &&
10141         args->single_newmv_valid[mbmi->ref_mv_idx][ref] == 0) {
10142       return 0;
10143     }
10144   }
10145   return 1;
10146 }
10147 
get_drl_refmv_count(const MACROBLOCK * const x,const MV_REFERENCE_FRAME * ref_frame,PREDICTION_MODE mode)10148 static int get_drl_refmv_count(const MACROBLOCK *const x,
10149                                const MV_REFERENCE_FRAME *ref_frame,
10150                                PREDICTION_MODE mode) {
10151   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
10152   const int8_t ref_frame_type = av1_ref_frame_type(ref_frame);
10153   const int has_nearmv = have_nearmv_in_inter_mode(mode) ? 1 : 0;
10154   const int ref_mv_count = mbmi_ext->ref_mv_count[ref_frame_type];
10155   const int only_newmv = (mode == NEWMV || mode == NEW_NEWMV);
10156   const int has_drl =
10157       (has_nearmv && ref_mv_count > 2) || (only_newmv && ref_mv_count > 1);
10158   const int ref_set =
10159       has_drl ? AOMMIN(MAX_REF_MV_SERCH, ref_mv_count - has_nearmv) : 1;
10160 
10161   return ref_set;
10162 }
10163 
// Per-DRL-candidate bookkeeping used by handle_inter_mode(): one entry per
// ref_mv_idx (that function declares an array of MAX_REF_MV_SERCH entries).
// At the start of each candidate, rd is reset to INT64_MAX and mv to
// INVALID_MV.
typedef struct {
  // RD cost recorded for this candidate (RDCOST of its rate/dist).
  // It stays INT64_MAX if the candidate was skipped or produced no valid
  // result. The repeated-NEWMV shortcut copies it from an earlier entry.
  int64_t rd;
  // DRL index signaling cost, as returned by get_drl_cost().
  int drl_cost;
  // Motion vector rate recorded for this candidate.
  int rate_mv;
  // mbmi->mv[0] after the candidate was evaluated; INVALID_MV if never set.
  int_mv mv;
} inter_mode_info;
10170 
handle_inter_mode(AV1_COMP * const cpi,TileDataEnc * tile_data,MACROBLOCK * x,BLOCK_SIZE bsize,RD_STATS * rd_stats,RD_STATS * rd_stats_y,RD_STATS * rd_stats_uv,int * disable_skip,int mi_row,int mi_col,HandleInterModeArgs * args,int64_t ref_best_rd,uint8_t * const tmp_buf,CompoundTypeRdBuffers * rd_buffers,int64_t * best_est_rd,const int do_tx_search,InterModesInfo * inter_modes_info)10171 static int64_t handle_inter_mode(
10172     AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
10173     BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
10174     RD_STATS *rd_stats_uv, int *disable_skip, int mi_row, int mi_col,
10175     HandleInterModeArgs *args, int64_t ref_best_rd, uint8_t *const tmp_buf,
10176     CompoundTypeRdBuffers *rd_buffers, int64_t *best_est_rd,
10177     const int do_tx_search, InterModesInfo *inter_modes_info) {
10178   const AV1_COMMON *cm = &cpi->common;
10179   const int num_planes = av1_num_planes(cm);
10180   MACROBLOCKD *xd = &x->e_mbd;
10181   MB_MODE_INFO *mbmi = xd->mi[0];
10182   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
10183   const int is_comp_pred = has_second_ref(mbmi);
10184   const PREDICTION_MODE this_mode = mbmi->mode;
10185   int i;
10186   int refs[2] = { mbmi->ref_frame[0],
10187                   (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
10188   int rate_mv = 0;
10189   int64_t rd = INT64_MAX;
10190 
10191   // do first prediction into the destination buffer. Do the next
10192   // prediction into a temporary buffer. Then keep track of which one
10193   // of these currently holds the best predictor, and use the other
10194   // one for future predictions. In the end, copy from tmp_buf to
10195   // dst if necessary.
10196   struct macroblockd_plane *p = xd->plane;
10197   const BUFFER_SET orig_dst = {
10198     { p[0].dst.buf, p[1].dst.buf, p[2].dst.buf },
10199     { p[0].dst.stride, p[1].dst.stride, p[2].dst.stride },
10200   };
10201   const BUFFER_SET tmp_dst = { { tmp_buf, tmp_buf + 1 * MAX_SB_SQUARE,
10202                                  tmp_buf + 2 * MAX_SB_SQUARE },
10203                                { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE } };
10204 
10205   int skip_txfm_sb = 0;
10206   int64_t skip_sse_sb = INT64_MAX;
10207   int16_t mode_ctx;
10208   const int masked_compound_used = is_any_masked_compound_used(bsize) &&
10209                                    cm->seq_params.enable_masked_compound;
10210   int64_t ret_val = INT64_MAX;
10211   const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
10212   RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
10213   int64_t best_rd = INT64_MAX;
10214   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
10215   MB_MODE_INFO best_mbmi = *mbmi;
10216   int best_disable_skip;
10217   int best_xskip;
10218   int64_t newmv_ret_val = INT64_MAX;
10219   int_mv backup_mv[2] = { { 0 } };
10220   int backup_rate_mv = 0;
10221   inter_mode_info mode_info[MAX_REF_MV_SERCH];
10222 
10223   int mode_search_mask[2];
10224   const int do_two_loop_comp_search =
10225       is_comp_pred && cpi->sf.two_loop_comp_search;
10226   if (do_two_loop_comp_search) {
10227     // TODO(debargha): Change this to try alternate ways of splitting
10228     // modes while doing two pass compound_mode search.
10229     mode_search_mask[0] = (1 << COMPOUND_AVERAGE);
10230   } else {
10231     mode_search_mask[0] = (1 << COMPOUND_AVERAGE) | (1 << COMPOUND_DISTWTD) |
10232                           (1 << COMPOUND_WEDGE) | (1 << COMPOUND_DIFFWTD);
10233   }
10234   mode_search_mask[1] = ((1 << COMPOUND_AVERAGE) | (1 << COMPOUND_DISTWTD) |
10235                          (1 << COMPOUND_WEDGE) | (1 << COMPOUND_DIFFWTD)) -
10236                         mode_search_mask[0];
10237 
10238   // TODO(jingning): This should be deprecated shortly.
10239   const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
10240   const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
10241 
10242   for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ++ref_mv_idx) {
10243     mode_info[ref_mv_idx].mv.as_int = INVALID_MV;
10244     mode_info[ref_mv_idx].rd = INT64_MAX;
10245 
10246     if (cpi->sf.reduce_inter_modes && ref_mv_idx > 0) {
10247       if (mbmi->ref_frame[0] == LAST2_FRAME ||
10248           mbmi->ref_frame[0] == LAST3_FRAME ||
10249           mbmi->ref_frame[1] == LAST2_FRAME ||
10250           mbmi->ref_frame[1] == LAST3_FRAME) {
10251         if (mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx + has_nearmv]
10252                 .weight < REF_CAT_LEVEL) {
10253           continue;
10254         }
10255       }
10256     }
10257     if (cpi->sf.prune_single_motion_modes_by_simple_trans && !is_comp_pred &&
10258         args->single_ref_first_pass == 0) {
10259       if (args->simple_rd_state[ref_mv_idx].early_skipped) {
10260         continue;
10261       }
10262     }
10263     av1_init_rd_stats(rd_stats);
10264 
10265     mbmi->interinter_comp.type = COMPOUND_AVERAGE;
10266     mbmi->comp_group_idx = 0;
10267     mbmi->compound_idx = 1;
10268     if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME;
10269 
10270     mode_ctx =
10271         av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame);
10272 
10273     mbmi->num_proj_ref = 0;
10274     mbmi->motion_mode = SIMPLE_TRANSLATION;
10275     mbmi->ref_mv_idx = ref_mv_idx;
10276 
10277     if (is_comp_pred && (!is_single_newmv_valid(args, mbmi, this_mode))) {
10278       continue;
10279     }
10280 
10281     rd_stats->rate += args->ref_frame_cost + args->single_comp_cost;
10282     const int drl_cost =
10283         get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type);
10284     rd_stats->rate += drl_cost;
10285     mode_info[ref_mv_idx].drl_cost = drl_cost;
10286 
10287     if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
10288         mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
10289       continue;
10290     }
10291 
10292     const RD_STATS backup_rd_stats = *rd_stats;
10293 
10294     for (int comp_loop_idx = 0; comp_loop_idx <= do_two_loop_comp_search;
10295          ++comp_loop_idx) {
10296       int rs = 0;
10297       int compmode_interinter_cost = 0;
10298 
10299       if (is_comp_pred && comp_loop_idx == 1) *rd_stats = backup_rd_stats;
10300 
10301       int_mv cur_mv[2];
10302       if (!build_cur_mv(cur_mv, this_mode, cm, x)) {
10303         continue;
10304       }
10305       if (have_newmv_in_inter_mode(this_mode)) {
10306         if (comp_loop_idx == 1) {
10307           cur_mv[0] = backup_mv[0];
10308           cur_mv[1] = backup_mv[1];
10309           rate_mv = backup_rate_mv;
10310         }
10311 
10312 #if CONFIG_COLLECT_COMPONENT_TIMING
10313         start_timing(cpi, handle_newmv_time);
10314 #endif
10315         if (cpi->sf.prune_single_motion_modes_by_simple_trans &&
10316             args->single_ref_first_pass == 0 && !is_comp_pred) {
10317           const int ref0 = mbmi->ref_frame[0];
10318           newmv_ret_val = args->single_newmv_valid[ref_mv_idx][ref0] ? 0 : 1;
10319           cur_mv[0] = args->single_newmv[ref_mv_idx][ref0];
10320           rate_mv = args->single_newmv_rate[ref_mv_idx][ref0];
10321         } else if (comp_loop_idx == 0) {
10322           newmv_ret_val = handle_newmv(cpi, x, bsize, cur_mv, mi_row, mi_col,
10323                                        &rate_mv, args);
10324 
10325           // Store cur_mv and rate_mv so that they can be restored in the next
10326           // iteration of the loop
10327           backup_mv[0] = cur_mv[0];
10328           backup_mv[1] = cur_mv[1];
10329           backup_rate_mv = rate_mv;
10330         }
10331 #if CONFIG_COLLECT_COMPONENT_TIMING
10332         end_timing(cpi, handle_newmv_time);
10333 #endif
10334 
10335         if (newmv_ret_val != 0) {
10336           continue;
10337         } else {
10338           rd_stats->rate += rate_mv;
10339         }
10340 
10341         if (cpi->sf.skip_repeated_newmv) {
10342           if (!is_comp_pred && this_mode == NEWMV && ref_mv_idx > 0) {
10343             int skip = 0;
10344             int this_rate_mv = 0;
10345             for (i = 0; i < ref_mv_idx; ++i) {
10346               // Check if the motion search result same as previous results
10347               if (cur_mv[0].as_int == args->single_newmv[i][refs[0]].as_int) {
10348                 // If the compared mode has no valid rd, it is unlikely this
10349                 // mode will be the best mode
10350                 if (mode_info[i].rd == INT64_MAX) {
10351                   skip = 1;
10352                   break;
10353                 }
10354                 // Compare the cost difference including drl cost and mv cost
10355                 if (mode_info[i].mv.as_int != INVALID_MV) {
10356                   const int compare_cost =
10357                       mode_info[i].rate_mv + mode_info[i].drl_cost;
10358                   const int_mv ref_mv = av1_get_ref_mv(x, 0);
10359                   this_rate_mv = av1_mv_bit_cost(
10360                       &mode_info[i].mv.as_mv, &ref_mv.as_mv, x->nmv_vec_cost,
10361                       x->mv_cost_stack, MV_COST_WEIGHT);
10362                   const int this_cost = this_rate_mv + drl_cost;
10363 
10364                   if (compare_cost < this_cost) {
10365                     skip = 1;
10366                     break;
10367                   } else {
10368                     // If the cost is less than current best result, make this
10369                     // the best and update corresponding variables
10370                     if (best_mbmi.ref_mv_idx == i) {
10371                       assert(best_rd != INT64_MAX);
10372                       best_mbmi.ref_mv_idx = ref_mv_idx;
10373                       best_rd_stats.rate += this_cost - compare_cost;
10374                       best_rd = RDCOST(x->rdmult, best_rd_stats.rate,
10375                                        best_rd_stats.dist);
10376                       if (best_rd < ref_best_rd) ref_best_rd = best_rd;
10377                       skip = 1;
10378                       break;
10379                     }
10380                   }
10381                 }
10382               }
10383             }
10384             if (skip) {
10385               args->modelled_rd[this_mode][ref_mv_idx][refs[0]] =
10386                   args->modelled_rd[this_mode][i][refs[0]];
10387               args->simple_rd[this_mode][ref_mv_idx][refs[0]] =
10388                   args->simple_rd[this_mode][i][refs[0]];
10389               mode_info[ref_mv_idx].rd = mode_info[i].rd;
10390               mode_info[ref_mv_idx].rate_mv = this_rate_mv;
10391               mode_info[ref_mv_idx].mv.as_int = mode_info[i].mv.as_int;
10392 
10393               restore_dst_buf(xd, orig_dst, num_planes);
10394               continue;
10395             }
10396           }
10397         }
10398       }
10399       for (i = 0; i < is_comp_pred + 1; ++i) {
10400         mbmi->mv[i].as_int = cur_mv[i].as_int;
10401       }
10402       const int ref_mv_cost = cost_mv_ref(x, this_mode, mode_ctx);
10403 #if USE_DISCOUNT_NEWMV_TEST
10404       // We don't include the cost of the second reference here, because there
10405       // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in
10406       // other words if you present them in that order, the second one is always
10407       // known if the first is known.
10408       //
10409       // Under some circumstances we discount the cost of new mv mode to
10410       // encourage initiation of a motion field.
10411       if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
10412         // discount_newmv_test only applies discount on NEWMV mode.
10413         assert(this_mode == NEWMV);
10414         rd_stats->rate += AOMMIN(cost_mv_ref(x, this_mode, mode_ctx),
10415                                  cost_mv_ref(x, NEARESTMV, mode_ctx));
10416       } else {
10417         rd_stats->rate += ref_mv_cost;
10418       }
10419 #else
10420       rd_stats->rate += ref_mv_cost;
10421 #endif
10422 
10423       if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
10424           mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
10425         continue;
10426       }
10427 
10428 #if CONFIG_COLLECT_COMPONENT_TIMING
10429       start_timing(cpi, compound_type_rd_time);
10430 #endif
10431       int skip_build_pred = 0;
10432       if (is_comp_pred) {
10433         if (mode_search_mask[comp_loop_idx] == (1 << COMPOUND_AVERAGE)) {
10434           // Only compound_average
10435           mbmi->interinter_comp.type = COMPOUND_AVERAGE;
10436           mbmi->num_proj_ref = 0;
10437           mbmi->motion_mode = SIMPLE_TRANSLATION;
10438           mbmi->comp_group_idx = 0;
10439           mbmi->compound_idx = 1;
10440           const int comp_index_ctx = get_comp_index_context(cm, xd);
10441           compmode_interinter_cost +=
10442               x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
10443         } else if (mode_search_mask[comp_loop_idx] == (1 << COMPOUND_DISTWTD)) {
10444           // Only compound_distwtd
10445           if (!cm->seq_params.order_hint_info.enable_dist_wtd_comp ||
10446               cpi->sf.use_dist_wtd_comp_flag == DIST_WTD_COMP_DISABLED ||
10447               (do_two_loop_comp_search && mbmi->mode == GLOBAL_GLOBALMV))
10448             continue;
10449           mbmi->interinter_comp.type = COMPOUND_DISTWTD;
10450           mbmi->num_proj_ref = 0;
10451           mbmi->motion_mode = SIMPLE_TRANSLATION;
10452           mbmi->comp_group_idx = 0;
10453           mbmi->compound_idx = 0;
10454           const int comp_index_ctx = get_comp_index_context(cm, xd);
10455           compmode_interinter_cost +=
10456               x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
10457         } else {
10458           // Find matching interp filter or set to default interp filter
10459           const int need_search =
10460               av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd);
10461           int match_found = -1;
10462           const InterpFilter assign_filter = cm->interp_filter;
10463           int is_luma_interp_done = 0;
10464           if (cpi->sf.skip_repeat_interpolation_filter_search && need_search) {
10465             match_found = find_interp_filter_in_stats(x, mbmi);
10466           }
10467           if (!need_search || match_found == -1) {
10468             set_default_interp_filters(mbmi, assign_filter);
10469           }
10470 
10471           int64_t best_rd_compound;
10472           compmode_interinter_cost = compound_type_rd(
10473               cpi, x, bsize, mi_col, mi_row, cur_mv,
10474               mode_search_mask[comp_loop_idx], masked_compound_used, &orig_dst,
10475               &tmp_dst, rd_buffers, &rate_mv, &best_rd_compound, rd_stats,
10476               ref_best_rd, &is_luma_interp_done);
10477           if (ref_best_rd < INT64_MAX &&
10478               (best_rd_compound >> 4) * (11 + 2 * do_two_loop_comp_search) >
10479                   ref_best_rd) {
10480             restore_dst_buf(xd, orig_dst, num_planes);
10481             continue;
10482           }
10483           // No need to call av1_enc_build_inter_predictor for luma if
10484           // COMPOUND_AVERAGE is selected because it is the first
10485           // candidate in compound_type_rd, and the following
10486           // compound types searching uses tmp_dst buffer
10487 
10488           if (mbmi->interinter_comp.type == COMPOUND_AVERAGE &&
10489               is_luma_interp_done) {
10490             if (num_planes > 1) {
10491               av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst,
10492                                             bsize, AOM_PLANE_U, num_planes - 1);
10493             }
10494             skip_build_pred = 1;
10495           }
10496         }
10497       }
10498 #if CONFIG_COLLECT_COMPONENT_TIMING
10499       end_timing(cpi, compound_type_rd_time);
10500 #endif
10501 
10502 #if CONFIG_COLLECT_COMPONENT_TIMING
10503       start_timing(cpi, interpolation_filter_search_time);
10504 #endif
10505       ret_val = interpolation_filter_search(
10506           x, cpi, tile_data, bsize, mi_row, mi_col, &tmp_dst, &orig_dst,
10507           args->single_filter, &rd, &rs, &skip_txfm_sb, &skip_sse_sb,
10508           &skip_build_pred, args, ref_best_rd);
10509 #if CONFIG_COLLECT_COMPONENT_TIMING
10510       end_timing(cpi, interpolation_filter_search_time);
10511 #endif
10512       if (args->modelled_rd != NULL && !is_comp_pred) {
10513         args->modelled_rd[this_mode][ref_mv_idx][refs[0]] = rd;
10514       }
10515       if (ret_val != 0) {
10516         restore_dst_buf(xd, orig_dst, num_planes);
10517         continue;
10518       } else if (cpi->sf.model_based_post_interp_filter_breakout &&
10519                  ref_best_rd != INT64_MAX && (rd >> 3) * 3 > ref_best_rd) {
10520         restore_dst_buf(xd, orig_dst, num_planes);
10521         break;
10522       }
10523 
10524       if (!is_comp_pred)
10525         args->single_filter[this_mode][refs[0]] =
10526             av1_extract_interp_filter(mbmi->interp_filters, 0);
10527 
10528       if (args->modelled_rd != NULL) {
10529         if (is_comp_pred) {
10530           const int mode0 = compound_ref0_mode(this_mode);
10531           const int mode1 = compound_ref1_mode(this_mode);
10532           const int64_t mrd =
10533               AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
10534                      args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
10535           if ((rd >> 3) * 6 > mrd && ref_best_rd < INT64_MAX) {
10536             restore_dst_buf(xd, orig_dst, num_planes);
10537             continue;
10538           }
10539         }
10540       }
10541       rd_stats->rate += compmode_interinter_cost;
10542       if (skip_build_pred != 1) {
10543         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst, bsize,
10544                                       0, av1_num_planes(cm) - 1);
10545       }
10546 
10547       if (cpi->sf.second_loop_comp_fast_tx_search && comp_loop_idx == 1) {
10548         // TODO(chengchen): this speed feature introduces big loss.
10549         // Need better estimation of rate distortion.
10550         int dummy_rate;
10551         int64_t dummy_dist;
10552         int plane_rate[MAX_MB_PLANE] = { 0 };
10553         int64_t plane_sse[MAX_MB_PLANE] = { 0 };
10554         int64_t plane_dist[MAX_MB_PLANE] = { 0 };
10555 
10556         model_rd_sb_fn[MODELRD_TYPE_DIST_WTD_COMPOUND](
10557             cpi, bsize, x, xd, 0, num_planes - 1, mi_row, mi_col, &dummy_rate,
10558             &dummy_dist, &skip_txfm_sb, &skip_sse_sb, plane_rate, plane_sse,
10559             plane_dist);
10560 
10561         rd_stats->rate += rs;
10562         rd_stats->rate += plane_rate[0] + plane_rate[1] + plane_rate[2];
10563         rd_stats_y->rate = plane_rate[0];
10564         rd_stats_uv->rate = plane_rate[1] + plane_rate[2];
10565         rd_stats->sse = plane_sse[0] + plane_sse[1] + plane_sse[2];
10566         rd_stats_y->sse = plane_sse[0];
10567         rd_stats_uv->sse = plane_sse[1] + plane_sse[2];
10568         rd_stats->dist = plane_dist[0] + plane_dist[1] + plane_dist[2];
10569         rd_stats_y->dist = plane_dist[0];
10570         rd_stats_uv->dist = plane_dist[1] + plane_dist[2];
10571       } else {
10572 #if CONFIG_COLLECT_COMPONENT_TIMING
10573         start_timing(cpi, motion_mode_rd_time);
10574 #endif
10575         ret_val = motion_mode_rd(cpi, tile_data, x, bsize, rd_stats, rd_stats_y,
10576                                  rd_stats_uv, disable_skip, mi_row, mi_col,
10577                                  args, ref_best_rd, refs, &rate_mv, &orig_dst,
10578                                  best_est_rd, do_tx_search, inter_modes_info);
10579 #if CONFIG_COLLECT_COMPONENT_TIMING
10580         end_timing(cpi, motion_mode_rd_time);
10581 #endif
10582       }
10583       mode_info[ref_mv_idx].mv.as_int = mbmi->mv[0].as_int;
10584       mode_info[ref_mv_idx].rate_mv = rate_mv;
10585       if (ret_val != INT64_MAX) {
10586         int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
10587         mode_info[ref_mv_idx].rd = tmp_rd;
10588         if (tmp_rd < best_rd) {
10589           best_rd_stats = *rd_stats;
10590           best_rd_stats_y = *rd_stats_y;
10591           best_rd_stats_uv = *rd_stats_uv;
10592           best_rd = tmp_rd;
10593           best_mbmi = *mbmi;
10594           best_disable_skip = *disable_skip;
10595           best_xskip = x->skip;
10596           memcpy(best_blk_skip, x->blk_skip,
10597                  sizeof(best_blk_skip[0]) * xd->n4_h * xd->n4_w);
10598         }
10599 
10600         if (tmp_rd < ref_best_rd) {
10601           ref_best_rd = tmp_rd;
10602         }
10603       }
10604       restore_dst_buf(xd, orig_dst, num_planes);
10605     }
10606   }
10607 
10608   if (best_rd == INT64_MAX) return INT64_MAX;
10609 
10610   // re-instate status of the best choice
10611   *rd_stats = best_rd_stats;
10612   *rd_stats_y = best_rd_stats_y;
10613   *rd_stats_uv = best_rd_stats_uv;
10614   *mbmi = best_mbmi;
10615   *disable_skip = best_disable_skip;
10616   x->skip = best_xskip;
10617   assert(IMPLIES(mbmi->comp_group_idx == 1,
10618                  mbmi->interinter_comp.type != COMPOUND_AVERAGE));
10619   memcpy(x->blk_skip, best_blk_skip,
10620          sizeof(best_blk_skip[0]) * xd->n4_h * xd->n4_w);
10621 
10622   return RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
10623 }
10624 
rd_pick_intrabc_mode_sb(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t best_rd)10625 static int64_t rd_pick_intrabc_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x,
10626                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
10627                                        int64_t best_rd) {
10628   const AV1_COMMON *const cm = &cpi->common;
10629   if (!av1_allow_intrabc(cm) || !cpi->oxcf.enable_intrabc) return INT64_MAX;
10630   const int num_planes = av1_num_planes(cm);
10631 
10632   MACROBLOCKD *const xd = &x->e_mbd;
10633   const TileInfo *tile = &xd->tile;
10634   MB_MODE_INFO *mbmi = xd->mi[0];
10635   const int mi_row = -xd->mb_to_top_edge / (8 * MI_SIZE);
10636   const int mi_col = -xd->mb_to_left_edge / (8 * MI_SIZE);
10637   const int w = block_size_wide[bsize];
10638   const int h = block_size_high[bsize];
10639   const int sb_row = mi_row >> cm->seq_params.mib_size_log2;
10640   const int sb_col = mi_col >> cm->seq_params.mib_size_log2;
10641 
10642   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
10643   MV_REFERENCE_FRAME ref_frame = INTRA_FRAME;
10644   av1_find_mv_refs(cm, xd, mbmi, ref_frame, mbmi_ext->ref_mv_count,
10645                    mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
10646                    mi_col, mbmi_ext->mode_context);
10647 
10648   int_mv nearestmv, nearmv;
10649   av1_find_best_ref_mvs_from_stack(0, mbmi_ext, ref_frame, &nearestmv, &nearmv,
10650                                    0);
10651 
10652   if (nearestmv.as_int == INVALID_MV) {
10653     nearestmv.as_int = 0;
10654   }
10655   if (nearmv.as_int == INVALID_MV) {
10656     nearmv.as_int = 0;
10657   }
10658 
10659   int_mv dv_ref = nearestmv.as_int == 0 ? nearmv : nearestmv;
10660   if (dv_ref.as_int == 0)
10661     av1_find_ref_dv(&dv_ref, tile, cm->seq_params.mib_size, mi_row, mi_col);
10662   // Ref DV should not have sub-pel.
10663   assert((dv_ref.as_mv.col & 7) == 0);
10664   assert((dv_ref.as_mv.row & 7) == 0);
10665   mbmi_ext->ref_mv_stack[INTRA_FRAME][0].this_mv = dv_ref;
10666 
10667   struct buf_2d yv12_mb[MAX_MB_PLANE];
10668   av1_setup_pred_block(xd, yv12_mb, xd->cur_buf, mi_row, mi_col, NULL, NULL,
10669                        num_planes);
10670   for (int i = 0; i < num_planes; ++i) {
10671     xd->plane[i].pre[0] = yv12_mb[i];
10672   }
10673 
10674   enum IntrabcMotionDirection {
10675     IBC_MOTION_ABOVE,
10676     IBC_MOTION_LEFT,
10677     IBC_MOTION_DIRECTIONS
10678   };
10679 
10680   MB_MODE_INFO best_mbmi = *mbmi;
10681   RD_STATS best_rdstats = *rd_stats;
10682   int best_skip = x->skip;
10683 
10684   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE] = { 0 };
10685   for (enum IntrabcMotionDirection dir = IBC_MOTION_ABOVE;
10686        dir < IBC_MOTION_DIRECTIONS; ++dir) {
10687     const MvLimits tmp_mv_limits = x->mv_limits;
10688     switch (dir) {
10689       case IBC_MOTION_ABOVE:
10690         x->mv_limits.col_min = (tile->mi_col_start - mi_col) * MI_SIZE;
10691         x->mv_limits.col_max = (tile->mi_col_end - mi_col) * MI_SIZE - w;
10692         x->mv_limits.row_min = (tile->mi_row_start - mi_row) * MI_SIZE;
10693         x->mv_limits.row_max =
10694             (sb_row * cm->seq_params.mib_size - mi_row) * MI_SIZE - h;
10695         break;
10696       case IBC_MOTION_LEFT:
10697         x->mv_limits.col_min = (tile->mi_col_start - mi_col) * MI_SIZE;
10698         x->mv_limits.col_max =
10699             (sb_col * cm->seq_params.mib_size - mi_col) * MI_SIZE - w;
10700         // TODO(aconverse@google.com): Minimize the overlap between above and
10701         // left areas.
10702         x->mv_limits.row_min = (tile->mi_row_start - mi_row) * MI_SIZE;
10703         int bottom_coded_mi_edge =
10704             AOMMIN((sb_row + 1) * cm->seq_params.mib_size, tile->mi_row_end);
10705         x->mv_limits.row_max = (bottom_coded_mi_edge - mi_row) * MI_SIZE - h;
10706         break;
10707       default: assert(0);
10708     }
10709     assert(x->mv_limits.col_min >= tmp_mv_limits.col_min);
10710     assert(x->mv_limits.col_max <= tmp_mv_limits.col_max);
10711     assert(x->mv_limits.row_min >= tmp_mv_limits.row_min);
10712     assert(x->mv_limits.row_max <= tmp_mv_limits.row_max);
10713     av1_set_mv_search_range(&x->mv_limits, &dv_ref.as_mv);
10714 
10715     if (x->mv_limits.col_max < x->mv_limits.col_min ||
10716         x->mv_limits.row_max < x->mv_limits.row_min) {
10717       x->mv_limits = tmp_mv_limits;
10718       continue;
10719     }
10720 
10721     int step_param = cpi->mv_step_param;
10722     MV mvp_full = dv_ref.as_mv;
10723     mvp_full.col >>= 3;
10724     mvp_full.row >>= 3;
10725     const int sadpb = x->sadperbit16;
10726     int cost_list[5];
10727     const int bestsme = av1_full_pixel_search(
10728         cpi, x, bsize, &mvp_full, step_param, cpi->sf.mv.search_method, 0,
10729         sadpb, cond_cost_list(cpi, cost_list), &dv_ref.as_mv, INT_MAX, 1,
10730         (MI_SIZE * mi_col), (MI_SIZE * mi_row), 1,
10731         &cpi->ss_cfg[SS_CFG_LOOKAHEAD]);
10732 
10733     x->mv_limits = tmp_mv_limits;
10734     if (bestsme == INT_MAX) continue;
10735     mvp_full = x->best_mv.as_mv;
10736     const MV dv = { .row = mvp_full.row * 8, .col = mvp_full.col * 8 };
10737     if (mv_check_bounds(&x->mv_limits, &dv)) continue;
10738     if (!av1_is_dv_valid(dv, cm, xd, mi_row, mi_col, bsize,
10739                          cm->seq_params.mib_size_log2))
10740       continue;
10741 
10742     // DV should not have sub-pel.
10743     assert((dv.col & 7) == 0);
10744     assert((dv.row & 7) == 0);
10745     memset(&mbmi->palette_mode_info, 0, sizeof(mbmi->palette_mode_info));
10746     mbmi->filter_intra_mode_info.use_filter_intra = 0;
10747     mbmi->use_intrabc = 1;
10748     mbmi->mode = DC_PRED;
10749     mbmi->uv_mode = UV_DC_PRED;
10750     mbmi->motion_mode = SIMPLE_TRANSLATION;
10751     mbmi->mv[0].as_mv = dv;
10752     mbmi->interp_filters = av1_broadcast_interp_filter(BILINEAR);
10753     mbmi->skip = 0;
10754     x->skip = 0;
10755     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
10756                                   av1_num_planes(cm) - 1);
10757 
10758     int *dvcost[2] = { (int *)&cpi->dv_cost[0][MV_MAX],
10759                        (int *)&cpi->dv_cost[1][MV_MAX] };
10760     // TODO(aconverse@google.com): The full motion field defining discount
10761     // in MV_COST_WEIGHT is too large. Explore other values.
10762     const int rate_mv = av1_mv_bit_cost(&dv, &dv_ref.as_mv, cpi->dv_joint_cost,
10763                                         dvcost, MV_COST_WEIGHT_SUB);
10764     const int rate_mode = x->intrabc_cost[1];
10765     RD_STATS rd_stats_yuv, rd_stats_y, rd_stats_uv;
10766     if (!txfm_search(cpi, NULL, x, bsize, mi_row, mi_col, &rd_stats_yuv,
10767                      &rd_stats_y, &rd_stats_uv, rate_mode + rate_mv, INT64_MAX))
10768       continue;
10769     rd_stats_yuv.rdcost =
10770         RDCOST(x->rdmult, rd_stats_yuv.rate, rd_stats_yuv.dist);
10771     if (rd_stats_yuv.rdcost < best_rd) {
10772       best_rd = rd_stats_yuv.rdcost;
10773       best_mbmi = *mbmi;
10774       best_skip = mbmi->skip;
10775       best_rdstats = rd_stats_yuv;
10776       memcpy(best_blk_skip, x->blk_skip,
10777              sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
10778     }
10779   }
10780   *mbmi = best_mbmi;
10781   *rd_stats = best_rdstats;
10782   x->skip = best_skip;
10783   memcpy(x->blk_skip, best_blk_skip,
10784          sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
10785 #if CONFIG_RD_DEBUG
10786   mbmi->rd_stats = *rd_stats;
10787 #endif
10788   return best_rd;
10789 }
10790 
// Intra-only mode decision for the current block.
//
// Searches the best luma intra mode, then (for multi-plane streams) the best
// chroma intra mode, combines them into 'rd_cost' including the skip flag
// cost, and finally lets intra block copy try to beat that result.
//
// If no candidate is found (rd_cost->rate is INT_MAX at the end) the function
// returns without refreshing ctx->mic / ctx->mbmi_ext; otherwise both are
// copied from the current block state.
void av1_rd_pick_intra_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
                               int mi_col, RD_STATS *rd_cost, BLOCK_SIZE bsize,
                               PICK_MODE_CONTEXT *ctx, int64_t best_rd) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int num_planes = av1_num_planes(cm);

  // Luma search results.
  int rate_y = 0, rate_y_tokenonly = 0, y_skip = 0;
  int64_t dist_y = 0;
  // Chroma search results (stay zero when chroma is not searched).
  int rate_uv = 0, rate_uv_tokenonly = 0, uv_skip = 0;
  int64_t dist_uv = 0;

  // Start from a plain single-reference intra block with no intrabc and a
  // zero motion vector.
  ctx->skip = 0;
  mbmi->ref_frame[0] = INTRA_FRAME;
  mbmi->ref_frame[1] = NONE_FRAME;
  mbmi->use_intrabc = 0;
  mbmi->mv[0].as_int = 0;

  const int64_t luma_rd =
      rd_pick_intra_sby_mode(cpi, x, mi_row, mi_col, &rate_y, &rate_y_tokenonly,
                             &dist_y, &y_skip, bsize, best_rd, ctx);

  if (luma_rd >= best_rd) {
    // Luma alone is already too expensive: no intra result.
    rd_cost->rate = INT_MAX;
  } else {
    // Only store reconstructed luma when there's chroma RDO. When there's no
    // chroma RDO, the reconstructed luma will be stored in encode_superblock().
    xd->cfl.is_chroma_reference =
        is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x,
                            cm->seq_params.subsampling_y);
    xd->cfl.store_y = store_cfl_required_rdo(cm, x);
    if (xd->cfl.store_y) {
      // Restore the luma skip flags chosen by the luma search, then re-encode
      // luma so its reconstruction is available to the chroma (CfL) search.
      memcpy(x->blk_skip, ctx->blk_skip,
             sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
      av1_encode_intra_block_plane(cpi, x, bsize, AOM_PLANE_Y,
                                   cpi->optimize_seg_arr[mbmi->segment_id],
                                   mi_row, mi_col);
      xd->cfl.store_y = 0;
    }

    if (num_planes > 1) {
      const TX_SIZE max_uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
      init_sbuv_mode(mbmi);
      if (!x->skip_chroma_rd) {
        rd_pick_intra_sbuv_mode(cpi, x, &rate_uv, &rate_uv_tokenonly, &dist_uv,
                                &uv_skip, bsize, max_uv_tx_size);
      }
    }

    // A block is coded as "skip" only when every searched plane is skippable;
    // in that case the token rates are removed from the total.
    const int skip_ctx = av1_get_skip_context(xd);
    const int whole_block_skip = y_skip && (uv_skip || x->skip_chroma_rd);
    if (whole_block_skip) {
      rd_cost->rate = rate_y + rate_uv - rate_y_tokenonly - rate_uv_tokenonly +
                      x->skip_cost[skip_ctx][1];
    } else {
      rd_cost->rate = rate_y + rate_uv + x->skip_cost[skip_ctx][0];
    }
    rd_cost->dist = dist_y + dist_uv;
    rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
  }

  // Tighten the budget for intra block copy with the intra result, if any.
  if (rd_cost->rate != INT_MAX && rd_cost->rdcost < best_rd)
    best_rd = rd_cost->rdcost;

  const int64_t intrabc_rd =
      rd_pick_intrabc_mode_sb(cpi, x, rd_cost, bsize, best_rd);
  if (intrabc_rd < best_rd) {
    ctx->skip = x->skip;
    memcpy(ctx->blk_skip, x->blk_skip,
           sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
    assert(rd_cost->rate != INT_MAX);
  }

  if (rd_cost->rate != INT_MAX) {
    ctx->mic = *xd->mi[0];
    ctx->mbmi_ext = *x->mbmi_ext;
  }
}
10864 
// Rebuilds the chroma palette color-index map (xd->plane[1].color_index_map)
// for the current block.
//
// The source U and V samples are interleaved into the k-means scratch buffer
// as (U, V) pairs. Each pair is assigned to the nearest of the
// pmi->palette_size[1] chroma palette entries already stored in
// mbmi->palette_mode_info. The map is then extended to the full plane block
// dimensions.
static void restore_uv_color_map(const AV1_COMP *const cpi, MACROBLOCK *x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const int use_hbd = cpi->common.seq_params.use_highbitdepth;
  const int stride = x->plane[1].src.stride;
  const uint8_t *const u_buf = x->plane[1].src.buf;
  const uint8_t *const v_buf = x->plane[2].src.buf;
  const uint16_t *const u_buf16 = CONVERT_TO_SHORTPTR(u_buf);
  const uint16_t *const v_buf16 = CONVERT_TO_SHORTPTR(v_buf);
  int *const kmeans_data = x->palette_buffer->kmeans_data_buf;
  uint8_t *const color_map = xd->plane[1].color_index_map;

  int block_w, block_h, rows, cols;
  av1_get_block_dimensions(mbmi->sb_type, 1, xd, &block_w, &block_h, &rows,
                           &cols);

  // Interleave samples row by row: kmeans_data[2k] = U, kmeans_data[2k+1] = V.
  int *out = kmeans_data;
  for (int row = 0; row < rows; ++row) {
    const int row_offset = row * stride;
    for (int col = 0; col < cols; ++col) {
      if (use_hbd) {
        *out++ = u_buf16[row_offset + col];
        *out++ = v_buf16[row_offset + col];
      } else {
        *out++ = u_buf[row_offset + col];
        *out++ = v_buf[row_offset + col];
      }
    }
  }

  // Centroids use the same (U, V) interleaving. U colors start at
  // palette_colors[PALETTE_MAX_SIZE] and V colors at
  // palette_colors[2 * PALETTE_MAX_SIZE].
  int centroids[2 * PALETTE_MAX_SIZE];
  for (int idx = 0; idx < pmi->palette_size[1]; ++idx) {
    centroids[2 * idx] = pmi->palette_colors[PALETTE_MAX_SIZE + idx];
    centroids[2 * idx + 1] = pmi->palette_colors[2 * PALETTE_MAX_SIZE + idx];
  }

  av1_calc_indices(kmeans_data, centroids, color_map, rows * cols,
                   pmi->palette_size[1], 2);
  extend_palette_color_map(color_map, cols, rows, block_w, block_h);
}
10906 
// Forward declaration. The definition is further down in this translation
// unit and is not visible in this excerpt.
// set_params_rd_pick_inter_mode() calls it right after filling the above/left
// neighbor prediction buffers (av1_build_prediction_by_above_preds /
// av1_build_prediction_by_left_preds). It passes the plane-0 buffers and
// strides as 'above'/'left'. NOTE(review): presumably this derives the
// weighted target used by the OBMC search; confirm against the definition.
static void calc_target_weighted_pred(const AV1_COMMON *cm, const MACROBLOCK *x,
                                      const MACROBLOCKD *xd, int mi_row,
                                      int mi_col, const uint8_t *above,
                                      int above_stride, const uint8_t *left,
                                      int left_stride);
10912 
// Evaluates skip mode for the current block.
//
// Skip mode is NEAREST_NEARESTMV compound prediction from the two references
// named in cm->current_frame.skip_mode_info, coded without residual. If its RD
// cost is no worse than the best intra/inter cost already in 'rd_cost' (after
// adding the "skip_mode off" flag cost), it is recorded as the winner in
// 'search_state', and 'rd_cost' and x->skip are updated.
//
// The function returns early, leaving 'search_state' and 'rd_cost' untouched,
// in any of these cases:
//  - either skip-mode reference index is INVALID_IDX;
//  - the reference pair has no mode index;
//  - one-sided compound is disabled while all references are one-sided;
//  - a single-reference MV list needed to build the compound list is missing;
//  - build_cur_mv() fails.
// Side effects that happen even on early return: x->compound_idx is set to 1,
// and once the reference pair is known, mbmi's mode/uv_mode/ref_frame fields
// are overwritten.
static void rd_pick_skip_mode(RD_STATS *rd_cost,
                              InterModeSearchState *search_state,
                              const AV1_COMP *const cpi, MACROBLOCK *const x,
                              BLOCK_SIZE bsize, int mi_row, int mi_col,
                              struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
  const AV1_COMMON *const cm = &cpi->common;
  const SkipModeInfo *const skip_mode_info = &cm->current_frame.skip_mode_info;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];

  x->compound_idx = 1;  // COMPOUND_AVERAGE
  RD_STATS skip_mode_rd_stats;
  av1_invalid_rd_stats(&skip_mode_rd_stats);

  if (skip_mode_info->ref_frame_idx_0 == INVALID_IDX ||
      skip_mode_info->ref_frame_idx_1 == INVALID_IDX) {
    return;
  }

  const MV_REFERENCE_FRAME ref_frame =
      LAST_FRAME + skip_mode_info->ref_frame_idx_0;
  const MV_REFERENCE_FRAME second_ref_frame =
      LAST_FRAME + skip_mode_info->ref_frame_idx_1;
  const PREDICTION_MODE this_mode = NEAREST_NEARESTMV;
  const int mode_index =
      get_prediction_mode_idx(this_mode, ref_frame, second_ref_frame);

  if (mode_index == -1) {
    return;
  }

  if (!cpi->oxcf.enable_onesided_comp && cpi->all_one_sided_refs) {
    return;
  }

  mbmi->mode = this_mode;
  mbmi->uv_mode = UV_DC_PRED;
  mbmi->ref_frame[0] = ref_frame;
  mbmi->ref_frame[1] = second_ref_frame;
  const uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
  // ref_mv_count == UINT8_MAX means the MV reference list for that entry was
  // not built (set_params_rd_pick_inter_mode() initializes it to UINT8_MAX and
  // may skip the search). Build the compound list on demand, but only when
  // both single-reference lists exist.
  if (x->mbmi_ext->ref_mv_count[ref_frame_type] == UINT8_MAX) {
    if (x->mbmi_ext->ref_mv_count[ref_frame] == UINT8_MAX ||
        x->mbmi_ext->ref_mv_count[second_ref_frame] == UINT8_MAX) {
      return;
    }
    MB_MODE_INFO_EXT *mbmi_ext = x->mbmi_ext;
    av1_find_mv_refs(cm, xd, mbmi, ref_frame_type, mbmi_ext->ref_mv_count,
                     mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
                     mi_col, mbmi_ext->mode_context);
  }

  assert(this_mode == NEAREST_NEARESTMV);
  if (!build_cur_mv(mbmi->mv, this_mode, cm, x)) {
    return;
  }

  mbmi->filter_intra_mode_info.use_filter_intra = 0;
  mbmi->interintra_mode = (INTERINTRA_MODE)(II_DC_PRED - 1);
  mbmi->comp_group_idx = 0;
  mbmi->compound_idx = x->compound_idx;
  mbmi->interinter_comp.type = COMPOUND_AVERAGE;
  mbmi->motion_mode = SIMPLE_TRANSLATION;
  mbmi->ref_mv_idx = 0;
  mbmi->skip_mode = mbmi->skip = 1;

  set_default_interp_filters(mbmi, cm->interp_filter);

  set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
  for (int i = 0; i < num_planes; i++) {
    xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
    xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
  }

  BUFFER_SET orig_dst;
  for (int i = 0; i < num_planes; i++) {
    orig_dst.plane[i] = xd->plane[i].dst.buf;
    orig_dst.stride[i] = xd->plane[i].dst.stride;
  }

  // Obtain the rdcost for skip_mode.
  skip_mode_rd(&skip_mode_rd_stats, cpi, x, bsize, mi_row, mi_col, &orig_dst);

  // Compare the use of skip_mode with the best intra/inter mode obtained.
  // The incumbent must also pay for signaling "skip_mode = 0". INT_MAX is the
  // invalid-rate sentinel used for the int 'rate' field throughout this file.
  const int skip_mode_ctx = av1_get_skip_mode_context(xd);
  const int64_t best_intra_inter_mode_cost =
      (rd_cost->dist < INT64_MAX && rd_cost->rate < INT_MAX)
          ? RDCOST(x->rdmult,
                   rd_cost->rate + x->skip_mode_cost[skip_mode_ctx][0],
                   rd_cost->dist)
          : INT64_MAX;

  // In lossless segments, skip mode is only acceptable when it is exact.
  if (skip_mode_rd_stats.rdcost <= best_intra_inter_mode_cost &&
      (!xd->lossless[mbmi->segment_id] || skip_mode_rd_stats.dist == 0)) {
    assert(mode_index != -1);
    // Start from the evaluated block state. (A store to
    // best_mbmode.skip_mode before this copy was dead and has been removed.)
    search_state->best_mbmode = *mbmi;

    search_state->best_mbmode.skip_mode = search_state->best_mbmode.skip = 1;
    search_state->best_mbmode.mode = NEAREST_NEARESTMV;
    search_state->best_mbmode.ref_frame[0] = mbmi->ref_frame[0];
    search_state->best_mbmode.ref_frame[1] = mbmi->ref_frame[1];
    search_state->best_mbmode.mv[0].as_int = mbmi->mv[0].as_int;
    search_state->best_mbmode.mv[1].as_int = mbmi->mv[1].as_int;
    search_state->best_mbmode.ref_mv_idx = 0;

    // Set up tx_size related variables for skip-specific loop filtering.
    search_state->best_mbmode.tx_size =
        block_signals_txsize(bsize) ? tx_size_from_tx_mode(bsize, cm->tx_mode)
                                    : max_txsize_rect_lookup[bsize];
    memset(search_state->best_mbmode.inter_tx_size,
           search_state->best_mbmode.tx_size,
           sizeof(search_state->best_mbmode.inter_tx_size));
    set_txfm_ctxs(search_state->best_mbmode.tx_size, xd->n4_w, xd->n4_h,
                  search_state->best_mbmode.skip && is_inter_block(mbmi), xd);

    // Set up color-related variables for skip mode.
    search_state->best_mbmode.uv_mode = UV_DC_PRED;
    search_state->best_mbmode.palette_mode_info.palette_size[0] = 0;
    search_state->best_mbmode.palette_mode_info.palette_size[1] = 0;

    search_state->best_mbmode.comp_group_idx = 0;
    search_state->best_mbmode.compound_idx = x->compound_idx;
    search_state->best_mbmode.interinter_comp.type = COMPOUND_AVERAGE;
    search_state->best_mbmode.motion_mode = SIMPLE_TRANSLATION;

    search_state->best_mbmode.interintra_mode =
        (INTERINTRA_MODE)(II_DC_PRED - 1);
    search_state->best_mbmode.filter_intra_mode_info.use_filter_intra = 0;

    set_default_interp_filters(&search_state->best_mbmode, cm->interp_filter);

    search_state->best_mode_index = mode_index;

    // Update rd_cost
    rd_cost->rate = skip_mode_rd_stats.rate;
    rd_cost->dist = rd_cost->sse = skip_mode_rd_stats.dist;
    rd_cost->rdcost = skip_mode_rd_stats.rdcost;

    search_state->best_rd = rd_cost->rdcost;
    search_state->best_skip2 = 1;
    search_state->best_mode_skippable = 1;

    x->skip = 1;
  }
}
11059 
11060 // speed feature: fast intra/inter transform type search
11061 // Used for speed >= 2
11062 // When this speed feature is on, in rd mode search, only DCT is used.
11063 // After the mode is determined, this function is called, to select
11064 // transform types and get accurate rdcost.
sf_refine_fast_tx_type_search(const AV1_COMP * cpi,MACROBLOCK * x,int mi_row,int mi_col,RD_STATS * rd_cost,BLOCK_SIZE bsize,PICK_MODE_CONTEXT * ctx,int best_mode_index,MB_MODE_INFO * best_mbmode,struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE],int best_rate_y,int best_rate_uv,int * best_skip2)11065 static void sf_refine_fast_tx_type_search(
11066     const AV1_COMP *cpi, MACROBLOCK *x, int mi_row, int mi_col,
11067     RD_STATS *rd_cost, BLOCK_SIZE bsize, PICK_MODE_CONTEXT *ctx,
11068     int best_mode_index, MB_MODE_INFO *best_mbmode,
11069     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE], int best_rate_y,
11070     int best_rate_uv, int *best_skip2) {
11071   const AV1_COMMON *const cm = &cpi->common;
11072   const SPEED_FEATURES *const sf = &cpi->sf;
11073   MACROBLOCKD *const xd = &x->e_mbd;
11074   MB_MODE_INFO *const mbmi = xd->mi[0];
11075   const int num_planes = av1_num_planes(cm);
11076 
11077   if (xd->lossless[mbmi->segment_id] == 0 && best_mode_index >= 0 &&
11078       ((sf->tx_type_search.fast_inter_tx_type_search == 1 &&
11079         is_inter_mode(best_mbmode->mode)) ||
11080        (sf->tx_type_search.fast_intra_tx_type_search == 1 &&
11081         !is_inter_mode(best_mbmode->mode)))) {
11082     int skip_blk = 0;
11083     RD_STATS rd_stats_y, rd_stats_uv;
11084     const int skip_ctx = av1_get_skip_context(xd);
11085 
11086     x->use_default_inter_tx_type = 0;
11087     x->use_default_intra_tx_type = 0;
11088 
11089     *mbmi = *best_mbmode;
11090 
11091     set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
11092 
11093     // Select prediction reference frames.
11094     for (int i = 0; i < num_planes; i++) {
11095       xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
11096       if (has_second_ref(mbmi))
11097         xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
11098     }
11099 
11100     if (is_inter_mode(mbmi->mode)) {
11101       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
11102                                     av1_num_planes(cm) - 1);
11103       if (mbmi->motion_mode == OBMC_CAUSAL)
11104         av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
11105 
11106       av1_subtract_plane(x, bsize, 0);
11107       if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
11108         pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
11109                               INT64_MAX);
11110         assert(rd_stats_y.rate != INT_MAX);
11111       } else {
11112         super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
11113         memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
11114         for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
11115           set_blk_skip(x, 0, i, rd_stats_y.skip);
11116       }
11117     } else {
11118       super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
11119     }
11120 
11121     if (num_planes > 1) {
11122       super_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
11123     } else {
11124       av1_init_rd_stats(&rd_stats_uv);
11125     }
11126 
11127     if (RDCOST(x->rdmult,
11128                x->skip_cost[skip_ctx][0] + rd_stats_y.rate + rd_stats_uv.rate,
11129                (rd_stats_y.dist + rd_stats_uv.dist)) >
11130         RDCOST(x->rdmult, x->skip_cost[skip_ctx][1],
11131                (rd_stats_y.sse + rd_stats_uv.sse))) {
11132       skip_blk = 1;
11133       rd_stats_y.rate = x->skip_cost[skip_ctx][1];
11134       rd_stats_uv.rate = 0;
11135       rd_stats_y.dist = rd_stats_y.sse;
11136       rd_stats_uv.dist = rd_stats_uv.sse;
11137     } else {
11138       skip_blk = 0;
11139       rd_stats_y.rate += x->skip_cost[skip_ctx][0];
11140     }
11141 
11142     if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
11143         RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
11144                (rd_stats_y.dist + rd_stats_uv.dist))) {
11145       best_mbmode->tx_size = mbmi->tx_size;
11146       av1_copy(best_mbmode->inter_tx_size, mbmi->inter_tx_size);
11147       memcpy(ctx->blk_skip, x->blk_skip,
11148              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
11149       av1_copy(best_mbmode->txk_type, mbmi->txk_type);
11150       rd_cost->rate +=
11151           (rd_stats_y.rate + rd_stats_uv.rate - best_rate_y - best_rate_uv);
11152       rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
11153       rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
11154       *best_skip2 = skip_blk;
11155     }
11156   }
11157 }
11158 
// Pruning state for the RD mode search of one block. It is filled by
// default_skip_mask() and init_mode_skip_mask(). A set bit or a true entry
// means "do not try this candidate".
typedef struct {
  // Mask for each reference frame, specifying which prediction modes to NOT try
  // during search. Bit (1 << mode) corresponds to PREDICTION_MODE 'mode'.
  uint32_t pred_modes[REF_FRAMES];
  // If ref_combo[i][j + 1] is true, do NOT try prediction using combination of
  // reference frames (i, j).
  // Note: indexing with 'j + 1' is due to the fact that 2nd reference can be -1
  // (NONE_FRAME), i.e. column 0 is the single-reference slot.
  bool ref_combo[REF_FRAMES][REF_FRAMES + 1];
} mode_skip_mask_t;
11169 
11170 // Update 'ref_combo' mask to disable given 'ref' in single and compound modes.
disable_reference(MV_REFERENCE_FRAME ref,bool ref_combo[REF_FRAMES][REF_FRAMES+1])11171 static void disable_reference(MV_REFERENCE_FRAME ref,
11172                               bool ref_combo[REF_FRAMES][REF_FRAMES + 1]) {
11173   for (MV_REFERENCE_FRAME ref2 = NONE_FRAME; ref2 < REF_FRAMES; ++ref2) {
11174     ref_combo[ref][ref2 + 1] = true;
11175   }
11176 }
11177 
11178 // Update 'ref_combo' mask to disable all inter references except ALTREF.
disable_inter_references_except_altref(bool ref_combo[REF_FRAMES][REF_FRAMES+1])11179 static void disable_inter_references_except_altref(
11180     bool ref_combo[REF_FRAMES][REF_FRAMES + 1]) {
11181   disable_reference(LAST_FRAME, ref_combo);
11182   disable_reference(LAST2_FRAME, ref_combo);
11183   disable_reference(LAST3_FRAME, ref_combo);
11184   disable_reference(GOLDEN_FRAME, ref_combo);
11185   disable_reference(BWDREF_FRAME, ref_combo);
11186   disable_reference(ALTREF2_FRAME, ref_combo);
11187 }
11188 
// (first reference, second reference) pairs that default_skip_mask()
// re-enables for REF_SET_REDUCED. Every other combination stays masked.
// NONE_FRAME in the second slot denotes single-reference prediction.
// NOTE(review): pairs with INTRA_FRAME in the second slot presumably cover
// inter-intra prediction; confirm with the mode search.
static const MV_REFERENCE_FRAME reduced_ref_combos[][2] = {
  { LAST_FRAME, NONE_FRAME },     { ALTREF_FRAME, NONE_FRAME },
  { LAST_FRAME, ALTREF_FRAME },   { GOLDEN_FRAME, NONE_FRAME },
  { INTRA_FRAME, NONE_FRAME },    { GOLDEN_FRAME, ALTREF_FRAME },
  { LAST_FRAME, GOLDEN_FRAME },   { LAST_FRAME, INTRA_FRAME },
  { LAST_FRAME, BWDREF_FRAME },   { LAST_FRAME, LAST3_FRAME },
  { GOLDEN_FRAME, BWDREF_FRAME }, { GOLDEN_FRAME, INTRA_FRAME },
  { BWDREF_FRAME, NONE_FRAME },   { BWDREF_FRAME, ALTREF_FRAME },
  { ALTREF_FRAME, INTRA_FRAME },  { BWDREF_FRAME, INTRA_FRAME },
};

// Pairs re-enabled for REF_SET_REALTIME: single-reference LAST, ALTREF and
// GOLDEN, plus intra. No compound combinations.
static const MV_REFERENCE_FRAME real_time_ref_combos[][2] = {
  { LAST_FRAME, NONE_FRAME },
  { ALTREF_FRAME, NONE_FRAME },
  { GOLDEN_FRAME, NONE_FRAME },
  { INTRA_FRAME, NONE_FRAME }
};

// Starting reference set for default_skip_mask(). REF_SET_FULL leaves every
// combination enabled. The others enable only the pairs in the matching table
// above. init_mode_skip_mask() chooses REF_SET_REALTIME when
// sf->use_real_time_ref_set is set, else REF_SET_REDUCED when
// oxcf.enable_reduced_reference_set is set.
typedef enum { REF_SET_FULL, REF_SET_REDUCED, REF_SET_REALTIME } REF_SET;
11208 
// Resets 'mask' to the starting point for reference set 'ref_set'.
//
// REF_SET_FULL: nothing is masked.
// REF_SET_REDUCED / REF_SET_REALTIME: all prediction modes stay enabled, and
// every (ref1, ref2) combination is first disabled. Only the pairs listed in
// reduced_ref_combos / real_time_ref_combos are then re-enabled.
static void default_skip_mask(mode_skip_mask_t *mask, REF_SET ref_set) {
  if (ref_set == REF_SET_FULL) {
    // Everything available by default.
    memset(mask, 0, sizeof(*mask));
    return;
  }

  // All modes available by default.
  memset(mask->pred_modes, 0, sizeof(mask->pred_modes));
  // All references disabled first.
  for (MV_REFERENCE_FRAME ref1 = INTRA_FRAME; ref1 < REF_FRAMES; ++ref1) {
    for (MV_REFERENCE_FRAME ref2 = NONE_FRAME; ref2 < REF_FRAMES; ++ref2) {
      mask->ref_combo[ref1][ref2 + 1] = true;
    }
  }

  // Initialized so the unreachable 'default' path never leaves an
  // indeterminate pointer behind. This also avoids -Wmaybe-uninitialized when
  // assert() is compiled out.
  const MV_REFERENCE_FRAME(*ref_set_combos)[2] = NULL;
  int num_ref_combos = 0;

  // Then enable reduced set of references explicitly.
  // The element count is computed in size_t and the quotient is cast. The
  // previous '(int)sizeof(a) / sizeof(a[0])' cast only the numerator, which
  // was then converted back to size_t for the division.
  switch (ref_set) {
    case REF_SET_REDUCED:
      ref_set_combos = reduced_ref_combos;
      num_ref_combos =
          (int)(sizeof(reduced_ref_combos) / sizeof(reduced_ref_combos[0]));
      break;
    case REF_SET_REALTIME:
      ref_set_combos = real_time_ref_combos;
      num_ref_combos =
          (int)(sizeof(real_time_ref_combos) / sizeof(real_time_ref_combos[0]));
      break;
    default:
      assert(0);
      num_ref_combos = 0;
      break;
  }

  for (int i = 0; i < num_ref_combos; ++i) {
    const MV_REFERENCE_FRAME *const this_combo = ref_set_combos[i];
    mask->ref_combo[this_combo[0]][this_combo[1] + 1] = false;
  }
}
11246 
// Builds the per-block pruning mask consulted by the RD mode search.
//
// Starts from default_skip_mask() for the selected reference set (real-time
// takes precedence over reduced; otherwise full), then:
//  - disables references missing from cpi->ref_frame_flags, or excluded by the
//    segment's SEG_LVL_REF_FRAME feature;
//  - masks INTER_NEAREST_NEAR_ZERO modes for references whose pred_mv_sad is
//    more than roughly 4x the smallest one;
//  - restricts the search when the source frame is an alt-ref frame;
//  - applies the alt_ref_search_fp and adaptive_mode_search speed features;
//  - disables intra above sf->max_intra_bsize and masks intra luma modes not
//    allowed by sf->intra_y_mode_mask.
// Reads x->pred_mv_sad[] and x->mbmi_ext. In this file,
// set_params_rd_pick_inter_mode() calls it after resetting pred_mv_sad to
// INT_MAX and running setup_buffer_ref_mvs_inter() for the available
// references.
static void init_mode_skip_mask(mode_skip_mask_t *mask, const AV1_COMP *cpi,
                                MACROBLOCK *x, BLOCK_SIZE bsize) {
  const AV1_COMMON *const cm = &cpi->common;
  const struct segmentation *const seg = &cm->seg;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  unsigned char segment_id = mbmi->segment_id;
  const SPEED_FEATURES *const sf = &cpi->sf;
  REF_SET ref_set = REF_SET_FULL;

  // Real-time reference set wins over the reduced one.
  if (sf->use_real_time_ref_set)
    ref_set = REF_SET_REALTIME;
  else if (cpi->oxcf.enable_reduced_reference_set)
    ref_set = REF_SET_REDUCED;

  default_skip_mask(mask, ref_set);

  // Smallest predicted-MV SAD over all inter references. It is the baseline
  // for spotting poor references below.
  int min_pred_mv_sad = INT_MAX;
  MV_REFERENCE_FRAME ref_frame;
  for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame)
    min_pred_mv_sad = AOMMIN(min_pred_mv_sad, x->pred_mv_sad[ref_frame]);

  for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
    if (!(cpi->ref_frame_flags & av1_ref_frame_flag_list[ref_frame])) {
      // Skip checking missing reference in both single and compound reference
      // modes.
      disable_reference(ref_frame, mask->ref_combo);
    } else {
      // Skip fixed mv modes for poor references, i.e. those whose SAD / 4
      // still exceeds the best reference's SAD.
      if ((x->pred_mv_sad[ref_frame] >> 2) > min_pred_mv_sad) {
        mask->pred_modes[ref_frame] |= INTER_NEAREST_NEAR_ZERO;
      }
    }
    if (segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME) &&
        get_segdata(seg, segment_id, SEG_LVL_REF_FRAME) != (int)ref_frame) {
      // Reference not used for the segment.
      disable_reference(ref_frame, mask->ref_combo);
    }
  }
  // Note: We use the following drop-out only if the SEG_LVL_REF_FRAME feature
  // is disabled for this segment. This is to prevent the possibility that we
  // end up unable to pick any mode.
  if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME)) {
    // Only consider GLOBALMV/ALTREF_FRAME for alt ref frame,
    // unless ARNR filtering is enabled in which case we want
    // an unfiltered alternative. We allow near/nearest as well
    // because they may result in zero-zero MVs but be cheaper.
    if (cpi->rc.is_src_frame_alt_ref && (cpi->oxcf.arnr_max_frames == 0)) {
      disable_inter_references_except_altref(mask->ref_combo);

      // Mask every ALTREF mode outside INTER_NEAREST_NEAR_ZERO. Then also mask
      // NEARMV / NEARESTMV whose candidate MV differs from the GLOBALMV one.
      mask->pred_modes[ALTREF_FRAME] = ~INTER_NEAREST_NEAR_ZERO;
      const MV_REFERENCE_FRAME tmp_ref_frames[2] = { ALTREF_FRAME, NONE_FRAME };
      int_mv near_mv, nearest_mv, global_mv;
      get_this_mv(&nearest_mv, NEARESTMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
      get_this_mv(&near_mv, NEARMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
      get_this_mv(&global_mv, GLOBALMV, 0, 0, tmp_ref_frames, x->mbmi_ext);

      if (near_mv.as_int != global_mv.as_int)
        mask->pred_modes[ALTREF_FRAME] |= (1 << NEARMV);
      if (nearest_mv.as_int != global_mv.as_int)
        mask->pred_modes[ALTREF_FRAME] |= (1 << NEARESTMV);
    }
  }

  // alt_ref_search_fp on an alt-ref source frame: search ALTREF only, with all
  // of its modes enabled. This overrides the ALTREF mode mask set above;
  // intra and every other inter reference are disabled.
  if (cpi->rc.is_src_frame_alt_ref) {
    if (sf->alt_ref_search_fp) {
      assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
      mask->pred_modes[ALTREF_FRAME] = 0;
      disable_inter_references_except_altref(mask->ref_combo);
      disable_reference(INTRA_FRAME, mask->ref_combo);
    }
  }

  // alt_ref_search_fp on non-shown frames: drop all ALTREF modes when ALTREF's
  // predicted-MV SAD is more than twice GOLDEN's (GOLDEN SAD must be valid).
  if (sf->alt_ref_search_fp)
    if (!cm->show_frame && x->pred_mv_sad[GOLDEN_FRAME] < INT_MAX)
      if (x->pred_mv_sad[ALTREF_FRAME] > (x->pred_mv_sad[GOLDEN_FRAME] << 1))
        mask->pred_modes[ALTREF_FRAME] |= INTER_ALL;

  // adaptive_mode_search: on shown, non-alt-ref-source frames at least 3
  // frames after the golden frame, drop all GOLDEN modes when GOLDEN's SAD is
  // roughly more than twice LAST's.
  if (sf->adaptive_mode_search) {
    if (cm->show_frame && !cpi->rc.is_src_frame_alt_ref &&
        cpi->rc.frames_since_golden >= 3)
      if ((x->pred_mv_sad[GOLDEN_FRAME] >> 1) > x->pred_mv_sad[LAST_FRAME])
        mask->pred_modes[GOLDEN_FRAME] |= INTER_ALL;
  }

  // No intra search for blocks larger than the speed feature allows.
  if (bsize > sf->max_intra_bsize) {
    disable_reference(INTRA_FRAME, mask->ref_combo);
  }

  // Mask intra luma modes not enabled for this block's largest transform size.
  mask->pred_modes[INTRA_FRAME] |=
      ~(sf->intra_y_mode_mask[max_txsize_lookup[bsize]]);
}
11339 
11340 // Please add/modify parameter setting in this function, making it consistent
11341 // and easy to read and maintain.
set_params_rd_pick_inter_mode(const AV1_COMP * cpi,MACROBLOCK * x,HandleInterModeArgs * args,BLOCK_SIZE bsize,int mi_row,int mi_col,mode_skip_mask_t * mode_skip_mask,int skip_ref_frame_mask,unsigned int ref_costs_single[REF_FRAMES],unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES],struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE])11342 static void set_params_rd_pick_inter_mode(
11343     const AV1_COMP *cpi, MACROBLOCK *x, HandleInterModeArgs *args,
11344     BLOCK_SIZE bsize, int mi_row, int mi_col, mode_skip_mask_t *mode_skip_mask,
11345     int skip_ref_frame_mask, unsigned int ref_costs_single[REF_FRAMES],
11346     unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES],
11347     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
11348   const AV1_COMMON *const cm = &cpi->common;
11349   const int num_planes = av1_num_planes(cm);
11350   MACROBLOCKD *const xd = &x->e_mbd;
11351   MB_MODE_INFO *const mbmi = xd->mi[0];
11352   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
11353   unsigned char segment_id = mbmi->segment_id;
11354   int dst_width1[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
11355   int dst_width2[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
11356                                    MAX_SB_SIZE >> 1 };
11357   int dst_height1[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
11358                                     MAX_SB_SIZE >> 1 };
11359   int dst_height2[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
11360 
11361   for (int i = 0; i < MB_MODE_COUNT; ++i)
11362     for (int k = 0; k < REF_FRAMES; ++k) args->single_filter[i][k] = SWITCHABLE;
11363 
11364   if (is_cur_buf_hbd(xd)) {
11365     int len = sizeof(uint16_t);
11366     args->above_pred_buf[0] = CONVERT_TO_BYTEPTR(x->above_pred_buf);
11367     args->above_pred_buf[1] =
11368         CONVERT_TO_BYTEPTR(x->above_pred_buf + (MAX_SB_SQUARE >> 1) * len);
11369     args->above_pred_buf[2] =
11370         CONVERT_TO_BYTEPTR(x->above_pred_buf + MAX_SB_SQUARE * len);
11371     args->left_pred_buf[0] = CONVERT_TO_BYTEPTR(x->left_pred_buf);
11372     args->left_pred_buf[1] =
11373         CONVERT_TO_BYTEPTR(x->left_pred_buf + (MAX_SB_SQUARE >> 1) * len);
11374     args->left_pred_buf[2] =
11375         CONVERT_TO_BYTEPTR(x->left_pred_buf + MAX_SB_SQUARE * len);
11376   } else {
11377     args->above_pred_buf[0] = x->above_pred_buf;
11378     args->above_pred_buf[1] = x->above_pred_buf + (MAX_SB_SQUARE >> 1);
11379     args->above_pred_buf[2] = x->above_pred_buf + MAX_SB_SQUARE;
11380     args->left_pred_buf[0] = x->left_pred_buf;
11381     args->left_pred_buf[1] = x->left_pred_buf + (MAX_SB_SQUARE >> 1);
11382     args->left_pred_buf[2] = x->left_pred_buf + MAX_SB_SQUARE;
11383   }
11384 
11385   av1_collect_neighbors_ref_counts(xd);
11386 
11387   estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
11388                            ref_costs_comp);
11389 
11390   MV_REFERENCE_FRAME ref_frame;
11391   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
11392     x->pred_mv_sad[ref_frame] = INT_MAX;
11393     x->mbmi_ext->mode_context[ref_frame] = 0;
11394     mbmi_ext->ref_mv_count[ref_frame] = UINT8_MAX;
11395     if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref_frame]) {
11396       if (mbmi->partition != PARTITION_NONE &&
11397           mbmi->partition != PARTITION_SPLIT) {
11398         if (skip_ref_frame_mask & (1 << ref_frame)) {
11399           int skip = 1;
11400           for (int r = ALTREF_FRAME + 1; r < MODE_CTX_REF_FRAMES; ++r) {
11401             if (!(skip_ref_frame_mask & (1 << r))) {
11402               const MV_REFERENCE_FRAME *rf = ref_frame_map[r - REF_FRAMES];
11403               if (rf[0] == ref_frame || rf[1] == ref_frame) {
11404                 skip = 0;
11405                 break;
11406               }
11407             }
11408           }
11409           if (skip) continue;
11410         }
11411       }
11412       assert(get_ref_frame_yv12_buf(cm, ref_frame) != NULL);
11413       setup_buffer_ref_mvs_inter(cpi, x, ref_frame, bsize, mi_row, mi_col,
11414                                  yv12_mb);
11415     }
11416   }
11417   // ref_frame = ALTREF_FRAME
11418   for (; ref_frame < MODE_CTX_REF_FRAMES; ++ref_frame) {
11419     x->mbmi_ext->mode_context[ref_frame] = 0;
11420     mbmi_ext->ref_mv_count[ref_frame] = UINT8_MAX;
11421     const MV_REFERENCE_FRAME *rf = ref_frame_map[ref_frame - REF_FRAMES];
11422     if (!((cpi->ref_frame_flags & av1_ref_frame_flag_list[rf[0]]) &&
11423           (cpi->ref_frame_flags & av1_ref_frame_flag_list[rf[1]]))) {
11424       continue;
11425     }
11426 
11427     if (mbmi->partition != PARTITION_NONE &&
11428         mbmi->partition != PARTITION_SPLIT) {
11429       if (skip_ref_frame_mask & (1 << ref_frame)) {
11430         continue;
11431       }
11432     }
11433     av1_find_mv_refs(cm, xd, mbmi, ref_frame, mbmi_ext->ref_mv_count,
11434                      mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
11435                      mi_col, mbmi_ext->mode_context);
11436   }
11437 
11438   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
11439 
11440   if (check_num_overlappable_neighbors(mbmi) &&
11441       is_motion_variation_allowed_bsize(bsize)) {
11442     av1_build_prediction_by_above_preds(cm, xd, mi_row, mi_col,
11443                                         args->above_pred_buf, dst_width1,
11444                                         dst_height1, args->above_pred_stride);
11445     av1_build_prediction_by_left_preds(cm, xd, mi_row, mi_col,
11446                                        args->left_pred_buf, dst_width2,
11447                                        dst_height2, args->left_pred_stride);
11448     av1_setup_dst_planes(xd->plane, bsize, &cm->cur_frame->buf, mi_row, mi_col,
11449                          0, num_planes);
11450     calc_target_weighted_pred(
11451         cm, x, xd, mi_row, mi_col, args->above_pred_buf[0],
11452         args->above_pred_stride[0], args->left_pred_buf[0],
11453         args->left_pred_stride[0]);
11454   }
11455 
11456   init_mode_skip_mask(mode_skip_mask, cpi, x, bsize);
11457 
11458   if (cpi->sf.tx_type_search.fast_intra_tx_type_search ||
11459       cpi->oxcf.use_intra_default_tx_only)
11460     x->use_default_intra_tx_type = 1;
11461   else
11462     x->use_default_intra_tx_type = 0;
11463 
11464   if (cpi->sf.tx_type_search.fast_inter_tx_type_search)
11465     x->use_default_inter_tx_type = 1;
11466   else
11467     x->use_default_inter_tx_type = 0;
11468   if (cpi->sf.skip_repeat_interpolation_filter_search) {
11469     x->interp_filter_stats_idx[0] = 0;
11470     x->interp_filter_stats_idx[1] = 0;
11471   }
11472   x->comp_rd_stats_idx = 0;
11473 }
11474 
11475 // TODO(kyslov): now this is very similar to set_params_rd_pick_inter_mode
11476 // (except that doesn't set ALTREF parameters)
11477 //               consider passing a flag to select non-rd path (similar to
11478 //               encode_sb_row)
set_params_nonrd_pick_inter_mode(const AV1_COMP * cpi,MACROBLOCK * x,HandleInterModeArgs * args,BLOCK_SIZE bsize,int mi_row,int mi_col,mode_skip_mask_t * mode_skip_mask,int skip_ref_frame_mask,unsigned int ref_costs_single[REF_FRAMES],unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES],struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE])11479 static void set_params_nonrd_pick_inter_mode(
11480     const AV1_COMP *cpi, MACROBLOCK *x, HandleInterModeArgs *args,
11481     BLOCK_SIZE bsize, int mi_row, int mi_col, mode_skip_mask_t *mode_skip_mask,
11482     int skip_ref_frame_mask, unsigned int ref_costs_single[REF_FRAMES],
11483     unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES],
11484     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
11485   const AV1_COMMON *const cm = &cpi->common;
11486   const int num_planes = av1_num_planes(cm);
11487   MACROBLOCKD *const xd = &x->e_mbd;
11488   MB_MODE_INFO *const mbmi = xd->mi[0];
11489   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
11490   unsigned char segment_id = mbmi->segment_id;
11491   int dst_width1[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
11492   int dst_width2[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
11493                                    MAX_SB_SIZE >> 1 };
11494   int dst_height1[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
11495                                     MAX_SB_SIZE >> 1 };
11496   int dst_height2[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
11497 
11498   for (int i = 0; i < MB_MODE_COUNT; ++i)
11499     for (int k = 0; k < REF_FRAMES; ++k) args->single_filter[i][k] = SWITCHABLE;
11500 
11501   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
11502     int len = sizeof(uint16_t);
11503     args->above_pred_buf[0] = CONVERT_TO_BYTEPTR(x->above_pred_buf);
11504     args->above_pred_buf[1] =
11505         CONVERT_TO_BYTEPTR(x->above_pred_buf + (MAX_SB_SQUARE >> 1) * len);
11506     args->above_pred_buf[2] =
11507         CONVERT_TO_BYTEPTR(x->above_pred_buf + MAX_SB_SQUARE * len);
11508     args->left_pred_buf[0] = CONVERT_TO_BYTEPTR(x->left_pred_buf);
11509     args->left_pred_buf[1] =
11510         CONVERT_TO_BYTEPTR(x->left_pred_buf + (MAX_SB_SQUARE >> 1) * len);
11511     args->left_pred_buf[2] =
11512         CONVERT_TO_BYTEPTR(x->left_pred_buf + MAX_SB_SQUARE * len);
11513   } else {
11514     args->above_pred_buf[0] = x->above_pred_buf;
11515     args->above_pred_buf[1] = x->above_pred_buf + (MAX_SB_SQUARE >> 1);
11516     args->above_pred_buf[2] = x->above_pred_buf + MAX_SB_SQUARE;
11517     args->left_pred_buf[0] = x->left_pred_buf;
11518     args->left_pred_buf[1] = x->left_pred_buf + (MAX_SB_SQUARE >> 1);
11519     args->left_pred_buf[2] = x->left_pred_buf + MAX_SB_SQUARE;
11520   }
11521 
11522   av1_collect_neighbors_ref_counts(xd);
11523 
11524   estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
11525                            ref_costs_comp);
11526 
11527   MV_REFERENCE_FRAME ref_frame;
11528   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
11529     x->pred_mv_sad[ref_frame] = INT_MAX;
11530     x->mbmi_ext->mode_context[ref_frame] = 0;
11531     mbmi_ext->ref_mv_count[ref_frame] = UINT8_MAX;
11532     if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref_frame]) {
11533       if (mbmi->partition != PARTITION_NONE &&
11534           mbmi->partition != PARTITION_SPLIT) {
11535         if (skip_ref_frame_mask & (1 << ref_frame)) {
11536           int skip = 1;
11537           for (int r = ALTREF_FRAME + 1; r < MODE_CTX_REF_FRAMES; ++r) {
11538             if (!(skip_ref_frame_mask & (1 << r))) {
11539               const MV_REFERENCE_FRAME *rf = ref_frame_map[r - REF_FRAMES];
11540               if (rf[0] == ref_frame || rf[1] == ref_frame) {
11541                 skip = 0;
11542                 break;
11543               }
11544             }
11545           }
11546           if (skip) continue;
11547         }
11548       }
11549       assert(get_ref_frame_yv12_buf(cm, ref_frame) != NULL);
11550       setup_buffer_ref_mvs_inter(cpi, x, ref_frame, bsize, mi_row, mi_col,
11551                                  yv12_mb);
11552     }
11553   }
11554   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
11555 
11556   if (check_num_overlappable_neighbors(mbmi) &&
11557       is_motion_variation_allowed_bsize(bsize)) {
11558     av1_build_prediction_by_above_preds(cm, xd, mi_row, mi_col,
11559                                         args->above_pred_buf, dst_width1,
11560                                         dst_height1, args->above_pred_stride);
11561     av1_build_prediction_by_left_preds(cm, xd, mi_row, mi_col,
11562                                        args->left_pred_buf, dst_width2,
11563                                        dst_height2, args->left_pred_stride);
11564     av1_setup_dst_planes(xd->plane, bsize, &cm->cur_frame->buf, mi_row, mi_col,
11565                          0, num_planes);
11566     calc_target_weighted_pred(
11567         cm, x, xd, mi_row, mi_col, args->above_pred_buf[0],
11568         args->above_pred_stride[0], args->left_pred_buf[0],
11569         args->left_pred_stride[0]);
11570   }
11571   init_mode_skip_mask(mode_skip_mask, cpi, x, bsize);
11572 
11573   if (cpi->sf.tx_type_search.fast_intra_tx_type_search)
11574     x->use_default_intra_tx_type = 1;
11575   else
11576     x->use_default_intra_tx_type = 0;
11577 
11578   if (cpi->sf.tx_type_search.fast_inter_tx_type_search)
11579     x->use_default_inter_tx_type = 1;
11580   else
11581     x->use_default_inter_tx_type = 0;
11582   if (cpi->sf.skip_repeat_interpolation_filter_search) {
11583     x->interp_filter_stats_idx[0] = 0;
11584     x->interp_filter_stats_idx[1] = 0;
11585   }
11586 }
11587 
// Tries luma palette coding (as a DC_PRED intra mode) for the current block.
// Chroma reuses the intra UV mode cached in search_state for the chroma
// transform size, computing and caching it on first use. If the resulting RD
// cost beats search_state->best_rd, the block's mode info, rate/distortion and
// blk_skip flags are recorded as the new best. Returns early, without updating
// the best mode, when no palette is found or the luma RD search fails.
static void search_palette_mode(const AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
                                int mi_col, RD_STATS *rd_cost,
                                PICK_MODE_CONTEXT *ctx, BLOCK_SIZE bsize,
                                MB_MODE_INFO *const mbmi,
                                PALETTE_MODE_INFO *const pmi,
                                unsigned int *ref_costs_single,
                                InterModeSearchState *search_state) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  const int num_planes = av1_num_planes(cm);
  const int blk_h = block_size_high[bsize];
  const int blk_w = block_size_wide[bsize];
  const int *const mode_costs = x->mbmode_cost[size_group_lookup[bsize]];
  uint8_t *const saved_color_map = x->palette_buffer->best_palette_color_map;
  uint8_t *const y_color_map = xd->plane[0].color_index_map;

  // Snapshot taken before mbmi is rewritten below; the palette search stores
  // its best candidate into this copy.
  MB_MODE_INFO palette_best_mbmi = *mbmi;
  uint8_t palette_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  int64_t palette_best_rd = search_state->best_rd;
  int64_t palette_best_model_rd = INT64_MAX;

  mbmi->mode = DC_PRED;
  mbmi->uv_mode = UV_DC_PRED;
  mbmi->ref_frame[0] = INTRA_FRAME;
  mbmi->ref_frame[1] = NONE_FRAME;
  const int palette_overhead = rd_pick_palette_intra_sby(
      cpi, x, bsize, mi_row, mi_col, mode_costs[DC_PRED], &palette_best_mbmi,
      saved_color_map, &palette_best_rd, &palette_best_model_rd, NULL, NULL,
      NULL, NULL, ctx, palette_blk_skip);
  if (pmi->palette_size[0] == 0) return;  // No usable luma palette.

  memcpy(x->blk_skip, palette_blk_skip,
         sizeof(palette_blk_skip[0]) * bsize_to_num_blk(bsize));
  memcpy(y_color_map, saved_color_map,
         blk_h * blk_w * sizeof(saved_color_map[0]));

  RD_STATS y_stats;
  super_block_yrd(cpi, x, &y_stats, bsize, search_state->best_rd);
  if (y_stats.rate == INT_MAX) return;

  int is_skippable = y_stats.skip;
  int64_t total_dist = y_stats.dist;
  int total_rate = y_stats.rate + palette_overhead;
  total_rate += ref_costs_single[INTRA_FRAME];

  TX_SIZE uv_tx = TX_4X4;
  if (num_planes > 1) {
    uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
    // Compute and cache the chroma intra choice for this tx size once.
    if (search_state->rate_uv_intra[uv_tx] == INT_MAX) {
      choose_intra_uv_mode(
          cpi, x, bsize, uv_tx, &search_state->rate_uv_intra[uv_tx],
          &search_state->rate_uv_tokenonly[uv_tx],
          &search_state->dist_uvs[uv_tx], &search_state->skip_uvs[uv_tx],
          &search_state->mode_uv[uv_tx]);
      search_state->pmi_uv[uv_tx] = *pmi;
      search_state->uv_angle_delta[uv_tx] = mbmi->angle_delta[PLANE_TYPE_UV];
    }
    const PALETTE_MODE_INFO *const uv_pmi = &search_state->pmi_uv[uv_tx];
    mbmi->uv_mode = search_state->mode_uv[uv_tx];
    pmi->palette_size[1] = uv_pmi->palette_size[1];
    if (pmi->palette_size[1] > 0) {
      // Copy the U and V palette colors that follow the Y colors.
      memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
             uv_pmi->palette_colors + PALETTE_MAX_SIZE,
             2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
    }
    mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta[uv_tx];
    is_skippable = is_skippable && search_state->skip_uvs[uv_tx];
    total_dist += search_state->dist_uvs[uv_tx];
    total_rate += search_state->rate_uv_intra[uv_tx];
  }

  // A skippable block drops its token-only rates; either way the matching
  // skip-flag cost is charged.
  const int skip_ctx = av1_get_skip_context(xd);
  if (is_skippable) {
    total_rate -= y_stats.rate;
    if (num_planes > 1) total_rate -= search_state->rate_uv_tokenonly[uv_tx];
  }
  total_rate += x->skip_cost[skip_ctx][is_skippable ? 1 : 0];

  const int64_t candidate_rd = RDCOST(x->rdmult, total_rate, total_dist);
  if (candidate_rd >= search_state->best_rd) return;

  // NOTE(review): 3 is a hard-coded av1_mode_order index, presumably meant to
  // denote the intra DC entry - confirm against the mode order table.
  search_state->best_mode_index = 3;
  mbmi->mv[0].as_int = 0;
  rd_cost->rate = total_rate;
  rd_cost->dist = total_dist;
  rd_cost->rdcost = candidate_rd;
  search_state->best_rd = candidate_rd;
  search_state->best_mbmode = *mbmi;
  search_state->best_skip2 = 0;
  search_state->best_mode_skippable = is_skippable;
  memcpy(ctx->blk_skip, x->blk_skip,
         sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
}
11681 
// Resets every field of *search_state before a block's inter mode search.
// The best-so-far RD starts at best_rd_so_far. Per-mode RD thresholds come from
// cpi->rd.threshes scaled by the tile's thresh_freq_fact (>> 5), and are zero
// for modes up to LAST_NEW_MV_INDEX. All cached RD values start at INT64_MAX
// and all cached reference slots at NONE_FRAME.
static void init_inter_mode_search_state(InterModeSearchState *search_state,
                                         const AV1_COMP *cpi,
                                         const TileDataEnc *tile_data,
                                         const MACROBLOCK *x, BLOCK_SIZE bsize,
                                         int64_t best_rd_so_far) {
  const unsigned char segment_id = x->e_mbd.mi[0]->segment_id;
  const int *const rd_threshes = cpi->rd.threshes[segment_id][bsize];

  // Best-mode bookkeeping.
  search_state->best_rd = best_rd_so_far;
  av1_zero(search_state->best_mbmode);
  search_state->best_rate_y = INT_MAX;
  search_state->best_rate_uv = INT_MAX;
  search_state->best_mode_skippable = 0;
  search_state->best_skip2 = 0;
  search_state->best_mode_index = -1;
  search_state->skip_intra_modes = 0;

  // Reference-distance tables start out as all -1 bytes.
  search_state->num_available_refs = 0;
  memset(search_state->dist_refs, -1, sizeof(search_state->dist_refs));
  memset(search_state->dist_order_refs, -1,
         sizeof(search_state->dist_order_refs));

  for (int idx = 0; idx < MAX_MODES; ++idx) {
    search_state->mode_threshold[idx] =
        (idx <= LAST_NEW_MV_INDEX)
            ? 0
            : ((int64_t)rd_threshes[idx] *
               tile_data->thresh_freq_fact[bsize][idx]) >>
                  5;
  }

  // Intra bookkeeping.
  search_state->best_intra_mode = DC_PRED;
  search_state->best_intra_rd = INT64_MAX;
  search_state->angle_stats_ready = 0;
  av1_zero(search_state->directional_mode_skip_mask);
  search_state->best_pred_sse = UINT_MAX;

  for (int tx = 0; tx < TX_SIZES_ALL; ++tx) {
    search_state->rate_uv_intra[tx] = INT_MAX;
  }
  av1_zero(search_state->pmi_uv);

  for (int pred_mode = 0; pred_mode < REFERENCE_MODES; ++pred_mode) {
    search_state->best_pred_rd[pred_mode] = INT64_MAX;
  }

  // Cached single-reference NEWMV results.
  av1_zero(search_state->single_newmv);
  av1_zero(search_state->single_newmv_rate);
  av1_zero(search_state->single_newmv_valid);

  for (int mode = 0; mode < MB_MODE_COUNT; ++mode) {
    for (int mv_idx = 0; mv_idx < MAX_REF_MV_SERCH; ++mv_idx) {
      for (int ref = 0; ref < REF_FRAMES; ++ref) {
        search_state->modelled_rd[mode][mv_idx][ref] = INT64_MAX;
        search_state->simple_rd[mode][mv_idx][ref] = INT64_MAX;
      }
    }
  }

  // Per-direction single-reference mode states and their RD ordering.
  for (int dir = 0; dir < 2; ++dir) {
    for (int mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
      for (int ref = 0; ref < FWD_REFS; ++ref) {
        SingleInterModeState *const plain =
            &search_state->single_state[dir][mode][ref];
        SingleInterModeState *const modelled =
            &search_state->single_state_modelled[dir][mode][ref];
        plain->ref_frame = NONE_FRAME;
        plain->rd = INT64_MAX;
        modelled->ref_frame = NONE_FRAME;
        modelled->rd = INT64_MAX;
        search_state->single_rd_order[dir][mode][ref] = NONE_FRAME;
      }
    }
  }
  av1_zero(search_state->single_state_cnt);
  av1_zero(search_state->single_state_modelled_cnt);
}
11772 
mask_says_skip(const mode_skip_mask_t * mode_skip_mask,const MV_REFERENCE_FRAME * ref_frame,const PREDICTION_MODE this_mode)11773 bool mask_says_skip(const mode_skip_mask_t *mode_skip_mask,
11774                     const MV_REFERENCE_FRAME *ref_frame,
11775                     const PREDICTION_MODE this_mode) {
11776   if (mode_skip_mask->pred_modes[ref_frame[0]] & (1 << this_mode)) {
11777     return true;
11778   }
11779 
11780   return mode_skip_mask->ref_combo[ref_frame[0]][ref_frame[1] + 1];
11781 }
11782 
// Returns 1 when the mode at av1_mode_order[mode_index] cannot be used for this
// block, and 0 otherwise.
//  - Compound modes (second reference > INTRA_FRAME) require all of: a
//    non-intra-only frame, a reference mode other than SINGLE_REFERENCE, the
//    second reference enabled in cpi->ref_frame_flags, no segment-level
//    reference feature, and a block size allowing compound references.
//  - Inter-intra modes (inter first reference, INTRA_FRAME second) require an
//    inter-intra-compatible prediction mode and block size.
static int inter_mode_compatible_skip(const AV1_COMP *cpi, const MACROBLOCK *x,
                                      BLOCK_SIZE bsize, int mode_index) {
  const AV1_COMMON *const cm = &cpi->common;
  const MV_REFERENCE_FRAME *const refs = av1_mode_order[mode_index].ref_frame;
  const PREDICTION_MODE mode = av1_mode_order[mode_index].mode;
  const unsigned char segment_id = x->e_mbd.mi[0]->segment_id;

  if (refs[1] > INTRA_FRAME) {
    const int comp_incompatible =
        frame_is_intra_only(cm) ||
        cm->current_frame.reference_mode == SINGLE_REFERENCE ||
        // The second reference of the pair must be available.
        !(cpi->ref_frame_flags & av1_ref_frame_flag_list[refs[1]]) ||
        // The segment-level reference feature permits only one reference.
        segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME) ||
        !is_comp_ref_allowed(bsize);
    if (comp_incompatible) return 1;
  }

  const int is_interintra = refs[0] > INTRA_FRAME && refs[1] == INTRA_FRAME;
  if (is_interintra && (!is_interintra_allowed_mode(mode) ||
                        !is_interintra_allowed_bsize(bsize))) {
    return 1;
  }

  return 0;
}
11819 
// Returns the bitwise OR of x->picked_ref_frames_mask over every mi unit
// covered by the bsize block at (mi_row, mi_col).
//
// Positions are taken relative to the enclosing superblock of
// mib_size x mib_size mi units; mib_size must be a power of two. The table is
// read with a fixed row stride of MAX_MIB_SIZE, replacing the former magic
// literal 32, which equals MAX_MIB_SIZE. This stride is assumed to match the
// code that fills x->picked_ref_frames_mask (confirm).
static int fetch_picked_ref_frames_mask(const MACROBLOCK *const x,
                                        BLOCK_SIZE bsize, int mib_size,
                                        int mi_row, int mi_col) {
  // A larger superblock would index past the MAX_MIB_SIZE-wide table.
  assert(mib_size > 0 && mib_size <= MAX_MIB_SIZE);
  const int sb_size_mask = mib_size - 1;
  const int mi_row_in_sb = mi_row & sb_size_mask;
  const int mi_col_in_sb = mi_col & sb_size_mask;
  const int mi_w = mi_size_wide[bsize];
  const int mi_h = mi_size_high[bsize];
  int picked_ref_frames_mask = 0;
  for (int i = mi_row_in_sb; i < mi_row_in_sb + mi_h; ++i) {
    for (int j = mi_col_in_sb; j < mi_col_in_sb + mi_w; ++j) {
      picked_ref_frames_mask |= x->picked_ref_frames_mask[i * MAX_MIB_SIZE + j];
    }
  }
  return picked_ref_frames_mask;
}
11836 
11837 // Case 1: return 0, means don't skip this mode
11838 // Case 2: return 1, means skip this mode completely
11839 // Case 3: return 2, means skip compound only, but still try single motion modes
inter_mode_search_order_independent_skip(const AV1_COMP * cpi,const MACROBLOCK * x,BLOCK_SIZE bsize,int mode_index,int mi_row,int mi_col,mode_skip_mask_t * mode_skip_mask,InterModeSearchState * search_state,int skip_ref_frame_mask)11840 static int inter_mode_search_order_independent_skip(
11841     const AV1_COMP *cpi, const MACROBLOCK *x, BLOCK_SIZE bsize, int mode_index,
11842     int mi_row, int mi_col, mode_skip_mask_t *mode_skip_mask,
11843     InterModeSearchState *search_state, int skip_ref_frame_mask) {
11844   const SPEED_FEATURES *const sf = &cpi->sf;
11845   const AV1_COMMON *const cm = &cpi->common;
11846   const OrderHintInfo *const order_hint_info = &cm->seq_params.order_hint_info;
11847   const CurrentFrame *const current_frame = &cm->current_frame;
11848   const MACROBLOCKD *const xd = &x->e_mbd;
11849   const MB_MODE_INFO *const mbmi = xd->mi[0];
11850   const MV_REFERENCE_FRAME *ref_frame = av1_mode_order[mode_index].ref_frame;
11851   const PREDICTION_MODE this_mode = av1_mode_order[mode_index].mode;
11852   const int comp_pred = ref_frame[1] > INTRA_FRAME;
11853   int skip_motion_mode = 0;
11854 
11855   if (mask_says_skip(mode_skip_mask, ref_frame, this_mode)) {
11856     return 1;
11857   }
11858 
11859   // If no valid mode has been found so far in PARTITION_NONE when finding a
11860   // valid partition is required, do not skip mode.
11861   if (search_state->best_rd == INT64_MAX && mbmi->partition == PARTITION_NONE &&
11862       x->must_find_valid_partition)
11863     return 0;
11864 
11865   if (mbmi->partition != PARTITION_NONE && mbmi->partition != PARTITION_SPLIT) {
11866     const int ref_type = av1_ref_frame_type(ref_frame);
11867     int skip_ref = skip_ref_frame_mask & (1 << ref_type);
11868     if (ref_type <= ALTREF_FRAME && skip_ref) {
11869       // Since the compound ref modes depends on the motion estimation result of
11870       // two single ref modes( best mv of single ref modes as the start point )
11871       // If current single ref mode is marked skip, we need to check if it will
11872       // be used in compound ref modes.
11873       for (int r = ALTREF_FRAME + 1; r < MODE_CTX_REF_FRAMES; ++r) {
11874         if (!(skip_ref_frame_mask & (1 << r))) {
11875           const MV_REFERENCE_FRAME *rf = ref_frame_map[r - REF_FRAMES];
11876           if (rf[0] == ref_type || rf[1] == ref_type) {
11877             // Found a not skipped compound ref mode which contains current
11878             // single ref. So this single ref can't be skipped completly
11879             // Just skip it's motion mode search, still try it's simple
11880             // transition mode.
11881             skip_motion_mode = 1;
11882             skip_ref = 0;
11883             break;
11884           }
11885         }
11886       }
11887     }
11888     if (skip_ref) return 1;
11889   }
11890 
11891   if (cpi->two_pass_partition_search && !x->cb_partition_scan) {
11892     const int mi_width = mi_size_wide[bsize];
11893     const int mi_height = mi_size_high[bsize];
11894     int found = 0;
11895     // Search in the stats table to see if the ref frames have been used in the
11896     // first pass of partition search.
11897     for (int row = mi_row; row < mi_row + mi_width && !found;
11898          row += FIRST_PARTITION_PASS_SAMPLE_REGION) {
11899       for (int col = mi_col; col < mi_col + mi_height && !found;
11900            col += FIRST_PARTITION_PASS_SAMPLE_REGION) {
11901         const int index = av1_first_partition_pass_stats_index(row, col);
11902         const FIRST_PARTITION_PASS_STATS *const stats =
11903             &x->first_partition_pass_stats[index];
11904         if (stats->ref0_counts[ref_frame[0]] &&
11905             (ref_frame[1] < 0 || stats->ref1_counts[ref_frame[1]])) {
11906           found = 1;
11907           break;
11908         }
11909       }
11910     }
11911     if (!found) return 1;
11912   }
11913 
11914   // This is only used in motion vector unit test.
11915   if (cpi->oxcf.motion_vector_unit_test && ref_frame[0] == INTRA_FRAME)
11916     return 1;
11917 
11918   if (ref_frame[0] == INTRA_FRAME) {
11919     if (this_mode != DC_PRED) {
11920       // Disable intra modes other than DC_PRED for blocks with low variance
11921       // Threshold for intra skipping based on source variance
11922       // TODO(debargha): Specialize the threshold for super block sizes
11923       const unsigned int skip_intra_var_thresh = 64;
11924       if ((sf->mode_search_skip_flags & FLAG_SKIP_INTRA_LOWVAR) &&
11925           x->source_variance < skip_intra_var_thresh)
11926         return 1;
11927     }
11928   }
11929 
11930   if (sf->selective_ref_frame) {
11931     if (sf->selective_ref_frame >= 3 || x->cb_partition_scan) {
11932       if (ref_frame[0] == ALTREF2_FRAME || ref_frame[1] == ALTREF2_FRAME)
11933         if (get_relative_dist(
11934                 order_hint_info,
11935                 cm->cur_frame->ref_order_hints[ALTREF2_FRAME - LAST_FRAME],
11936                 current_frame->order_hint) < 0)
11937           return 1;
11938       if (ref_frame[0] == BWDREF_FRAME || ref_frame[1] == BWDREF_FRAME)
11939         if (get_relative_dist(
11940                 order_hint_info,
11941                 cm->cur_frame->ref_order_hints[BWDREF_FRAME - LAST_FRAME],
11942                 current_frame->order_hint) < 0)
11943           return 1;
11944     }
11945 
11946     if (sf->selective_ref_frame >= 2 ||
11947         (sf->selective_ref_frame == 1 && comp_pred)) {
11948       if (ref_frame[0] == LAST3_FRAME || ref_frame[1] == LAST3_FRAME)
11949         if (get_relative_dist(
11950                 order_hint_info,
11951                 cm->cur_frame->ref_order_hints[LAST3_FRAME - LAST_FRAME],
11952                 cm->cur_frame->ref_order_hints[GOLDEN_FRAME - LAST_FRAME]) <= 0)
11953           return 1;
11954       if (ref_frame[0] == LAST2_FRAME || ref_frame[1] == LAST2_FRAME)
11955         if (get_relative_dist(
11956                 order_hint_info,
11957                 cm->cur_frame->ref_order_hints[LAST2_FRAME - LAST_FRAME],
11958                 cm->cur_frame->ref_order_hints[GOLDEN_FRAME - LAST_FRAME]) <= 0)
11959           return 1;
11960     }
11961   }
11962 
11963   // One-sided compound is used only when all reference frames are one-sided.
11964   if ((sf->selective_ref_frame >= 2) && comp_pred && !cpi->all_one_sided_refs) {
11965     unsigned int ref_offsets[2];
11966     for (int i = 0; i < 2; ++i) {
11967       const RefCntBuffer *const buf = get_ref_frame_buf(cm, ref_frame[i]);
11968       assert(buf != NULL);
11969       ref_offsets[i] = buf->order_hint;
11970     }
11971     if ((get_relative_dist(order_hint_info, ref_offsets[0],
11972                            current_frame->order_hint) <= 0 &&
11973          get_relative_dist(order_hint_info, ref_offsets[1],
11974                            current_frame->order_hint) <= 0) ||
11975         (get_relative_dist(order_hint_info, ref_offsets[0],
11976                            current_frame->order_hint) > 0 &&
11977          get_relative_dist(order_hint_info, ref_offsets[1],
11978                            current_frame->order_hint) > 0))
11979       return 1;
11980   }
11981 
11982   if (sf->selective_ref_frame >= 4 && comp_pred) {
11983     // Check if one of the reference is ALTREF2_FRAME and BWDREF_FRAME is a
11984     // valid reference.
11985     if ((ref_frame[0] == ALTREF2_FRAME || ref_frame[1] == ALTREF2_FRAME) &&
11986         (cpi->ref_frame_flags & av1_ref_frame_flag_list[BWDREF_FRAME])) {
11987       // Check if both ALTREF2_FRAME and BWDREF_FRAME are future references.
11988       if ((get_relative_dist(
11989                order_hint_info,
11990                cm->cur_frame->ref_order_hints[ALTREF2_FRAME - LAST_FRAME],
11991                current_frame->order_hint) > 0) &&
11992           (get_relative_dist(
11993                order_hint_info,
11994                cm->cur_frame->ref_order_hints[BWDREF_FRAME - LAST_FRAME],
11995                current_frame->order_hint) > 0)) {
11996         // Drop ALTREF2_FRAME as a reference if BWDREF_FRAME is a closer
11997         // reference to the current frame than ALTREF2_FRAME
11998         if (get_relative_dist(
11999                 order_hint_info,
12000                 cm->cur_frame->ref_order_hints[ALTREF2_FRAME - LAST_FRAME],
12001                 cm->cur_frame->ref_order_hints[BWDREF_FRAME - LAST_FRAME]) >=
12002             0) {
12003           const RefCntBuffer *const buf_arf2 =
12004               get_ref_frame_buf(cm, ALTREF2_FRAME);
12005           assert(buf_arf2 != NULL);
12006           const RefCntBuffer *const buf_bwd =
12007               get_ref_frame_buf(cm, BWDREF_FRAME);
12008           assert(buf_bwd != NULL);
12009           (void)buf_arf2;
12010           (void)buf_bwd;
12011           return 1;
12012         }
12013       }
12014     }
12015   }
12016 
12017   if (skip_repeated_mv(cm, x, this_mode, ref_frame, search_state)) {
12018     return 1;
12019   }
12020   if (skip_motion_mode) {
12021     return 2;
12022   }
12023 
12024   if (!cpi->oxcf.enable_global_motion &&
12025       (this_mode == GLOBALMV || this_mode == GLOBAL_GLOBALMV)) {
12026     return 1;
12027   }
12028 
12029   if (!cpi->oxcf.enable_onesided_comp && comp_pred && cpi->all_one_sided_refs) {
12030     return 1;
12031   }
12032 
12033   return 0;
12034 }
12035 
init_mbmi(MB_MODE_INFO * mbmi,int mode_index,const AV1_COMMON * cm)12036 static INLINE void init_mbmi(MB_MODE_INFO *mbmi, int mode_index,
12037                              const AV1_COMMON *cm) {
12038   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
12039   PREDICTION_MODE this_mode = av1_mode_order[mode_index].mode;
12040   mbmi->ref_mv_idx = 0;
12041   mbmi->mode = this_mode;
12042   mbmi->uv_mode = UV_DC_PRED;
12043   mbmi->ref_frame[0] = av1_mode_order[mode_index].ref_frame[0];
12044   mbmi->ref_frame[1] = av1_mode_order[mode_index].ref_frame[1];
12045   pmi->palette_size[0] = 0;
12046   pmi->palette_size[1] = 0;
12047   mbmi->filter_intra_mode_info.use_filter_intra = 0;
12048   mbmi->mv[0].as_int = mbmi->mv[1].as_int = 0;
12049   mbmi->motion_mode = SIMPLE_TRANSLATION;
12050   mbmi->interintra_mode = (INTERINTRA_MODE)(II_DC_PRED - 1);
12051   set_default_interp_filters(mbmi, cm->interp_filter);
12052 }
12053 
// Evaluates the RD cost of the intra mode already set in xd->mi[0]
// (mbmi->mode; ref_frame[0] must be INTRA_FRAME) as one candidate of the
// inter-frame mode search.
//
// Parameters:
//   search_state   - shared search state. Read: best_rd, best_mbmode.skip.
//                    Written: angle-estimation cache, per-uv_tx chroma cache
//                    (rate_uv_intra, rate_uv_tokenonly, dist_uvs, skip_uvs,
//                    mode_uv, pmi_uv, uv_angle_delta), best_intra_rd,
//                    best_intra_mode, best_pred_rd[] and skip_intra_modes.
//   ref_frame_cost - rate of signalling the intra reference frame.
//   ctx            - only ctx->num_4x4_blk is used (size of blk_skip copies).
//   disable_skip   - when 0, best_pred_rd[] entries are lowered to this_rd.
//   rd_stats, rd_stats_y, rd_stats_uv - outputs: combined, luma and chroma
//                    statistics for this mode.
//
// Returns the RD cost of the mode, or INT64_MAX if it was pruned or the luma
// search produced no valid result. Several of the pruning exits also set
// search_state->skip_intra_modes.
static int64_t handle_intra_mode(InterModeSearchState *search_state,
                                 const AV1_COMP *cpi, MACROBLOCK *x,
                                 BLOCK_SIZE bsize, int mi_row, int mi_col,
                                 int ref_frame_cost,
                                 const PICK_MODE_CONTEXT *ctx, int disable_skip,
                                 RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                                 RD_STATS *rd_stats_uv) {
  const AV1_COMMON *cm = &cpi->common;
  const SPEED_FEATURES *const sf = &cpi->sf;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(mbmi->ref_frame[0] == INTRA_FRAME);
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const int try_palette =
      cpi->oxcf.enable_palette &&
      av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
  const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]];
  const int intra_cost_penalty = av1_get_intra_cost_penalty(
      cm->base_qindex, cm->y_dc_delta_q, cm->seq_params.bit_depth);
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int num_planes = av1_num_planes(cm);
  const int skip_ctx = av1_get_skip_context(xd);

  // Lower bound on this mode's RD cost: mode, reference-frame and intra
  // penalty rates plus the cheaper of the two skip-flag costs, with zero
  // distortion. If even that exceeds best_rd the mode cannot win.
  int known_rate = intra_mode_cost[mbmi->mode];
  known_rate += ref_frame_cost;
  if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
    known_rate += intra_cost_penalty;
  known_rate += AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
  const int64_t known_rd = RDCOST(x->rdmult, known_rate, 0);
  if (known_rd > search_state->best_rd) {
    search_state->skip_intra_modes = 1;
    return INT64_MAX;
  }

  TX_SIZE uv_tx;
  int is_directional_mode = av1_is_directional_mode(mbmi->mode);
  if (is_directional_mode && av1_use_angle_delta(bsize) &&
      cpi->oxcf.enable_angle_delta) {
    // Directional mode with angle delta: angle estimation on the source block
    // runs at most once per search (angle_stats_ready) and produces a
    // per-mode skip mask that may reject this mode outright.
    int rate_dummy;
    int64_t model_rd = INT64_MAX;
    if (sf->intra_angle_estimation && !search_state->angle_stats_ready) {
      const int src_stride = x->plane[0].src.stride;
      const uint8_t *src = x->plane[0].src.buf;
      angle_estimation(src, src_stride, rows, cols, bsize, is_cur_buf_hbd(xd),
                       search_state->directional_mode_skip_mask);
      search_state->angle_stats_ready = 1;
    }
    if (search_state->directional_mode_skip_mask[mbmi->mode]) return INT64_MAX;
    av1_init_rd_stats(rd_stats_y);
    rd_stats_y->rate = INT_MAX;
    rd_pick_intra_angle_sby(cpi, x, mi_row, mi_col, &rate_dummy, rd_stats_y,
                            bsize, intra_mode_cost[mbmi->mode],
                            search_state->best_rd, &model_rd);
  } else {
    // Otherwise: plain luma transform search with zero angle delta.
    av1_init_rd_stats(rd_stats_y);
    mbmi->angle_delta[PLANE_TYPE_Y] = 0;
    super_block_yrd(cpi, x, rd_stats_y, bsize, search_state->best_rd);
  }
  // Snapshot blk_skip for the current luma result so it can be restored if
  // a filter intra candidate overwrites it without winning.
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  memcpy(best_blk_skip, x->blk_skip,
         sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
  int try_filter_intra = 0;
  int64_t best_rd_tmp = INT64_MAX;
  // Filter intra is only tried on top of DC_PRED. It is skipped when half of
  // the unfiltered luma RD already exceeds best_rd; if the luma search failed
  // (rate == INT_MAX) it is tried only when the best mode so far is not skip.
  if (mbmi->mode == DC_PRED && av1_filter_intra_allowed_bsize(cm, bsize)) {
    if (rd_stats_y->rate != INT_MAX) {
      const int tmp_rate = rd_stats_y->rate + x->filter_intra_cost[bsize][0] +
                           intra_mode_cost[mbmi->mode];
      best_rd_tmp = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
      try_filter_intra = !((best_rd_tmp / 2) > search_state->best_rd);
    } else {
      try_filter_intra = !(search_state->best_mbmode.skip);
    }
  }
  if (try_filter_intra) {
    RD_STATS rd_stats_y_fi;
    int filter_intra_selected_flag = 0;
    TX_SIZE best_tx_size = mbmi->tx_size;
    TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
    memcpy(best_txk_type, mbmi->txk_type,
           sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
    FILTER_INTRA_MODE best_fi_mode = FILTER_DC_PRED;

    mbmi->filter_intra_mode_info.use_filter_intra = 1;
    for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
         fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
      int64_t this_rd_tmp;
      mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
      super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, search_state->best_rd);
      if (rd_stats_y_fi.rate == INT_MAX) {
        continue;
      }
      const int this_rate_tmp =
          rd_stats_y_fi.rate +
          intra_mode_info_cost_y(cpi, x, mbmi, bsize,
                                 intra_mode_cost[mbmi->mode]);
      this_rd_tmp = RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);

      // Abandon the remaining filter modes once half of this one's RD
      // already exceeds best_rd.
      if (this_rd_tmp != INT64_MAX && this_rd_tmp / 2 > search_state->best_rd) {
        break;
      }
      if (this_rd_tmp < best_rd_tmp) {
        best_tx_size = mbmi->tx_size;
        memcpy(best_txk_type, mbmi->txk_type,
               sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
        memcpy(best_blk_skip, x->blk_skip,
               sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
        best_fi_mode = fi_mode;
        *rd_stats_y = rd_stats_y_fi;
        filter_intra_selected_flag = 1;
        best_rd_tmp = this_rd_tmp;
      }
    }

    // Restore the transform state of the winner: either the best filter
    // intra mode or the unfiltered DC_PRED result snapshotted above.
    mbmi->tx_size = best_tx_size;
    memcpy(mbmi->txk_type, best_txk_type,
           sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
    memcpy(x->blk_skip, best_blk_skip,
           sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);

    if (filter_intra_selected_flag) {
      mbmi->filter_intra_mode_info.use_filter_intra = 1;
      mbmi->filter_intra_mode_info.filter_intra_mode = best_fi_mode;
    } else {
      mbmi->filter_intra_mode_info.use_filter_intra = 0;
    }
  }
  if (rd_stats_y->rate == INT_MAX) return INT64_MAX;
  const int mode_cost_y =
      intra_mode_info_cost_y(cpi, x, mbmi, bsize, intra_mode_cost[mbmi->mode]);
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_uv);
  if (num_planes > 1) {
    uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
    // The chroma search is cached per uv transform size in search_state;
    // rate_uv_intra[uv_tx] == INT_MAX marks a slot not yet filled.
    if (search_state->rate_uv_intra[uv_tx] == INT_MAX) {
      // Before paying for the chroma search, prune if luma alone exceeds
      // 1.25x best_rd.
      int rate_y =
          rd_stats_y->skip ? x->skip_cost[skip_ctx][1] : rd_stats_y->rate;
      const int64_t rdy =
          RDCOST(x->rdmult, rate_y + mode_cost_y, rd_stats_y->dist);
      if (search_state->best_rd < (INT64_MAX / 2) &&
          rdy > (search_state->best_rd + (search_state->best_rd >> 2))) {
        search_state->skip_intra_modes = 1;
        return INT64_MAX;
      }
      choose_intra_uv_mode(
          cpi, x, bsize, uv_tx, &search_state->rate_uv_intra[uv_tx],
          &search_state->rate_uv_tokenonly[uv_tx],
          &search_state->dist_uvs[uv_tx], &search_state->skip_uvs[uv_tx],
          &search_state->mode_uv[uv_tx]);
      if (try_palette) search_state->pmi_uv[uv_tx] = *pmi;
      search_state->uv_angle_delta[uv_tx] = mbmi->angle_delta[PLANE_TYPE_UV];

      // Prune if the chroma token cost alone exceeds best_rd.
      const int uv_rate = search_state->rate_uv_tokenonly[uv_tx];
      const int64_t uv_dist = search_state->dist_uvs[uv_tx];
      const int64_t uv_rd = RDCOST(x->rdmult, uv_rate, uv_dist);
      if (uv_rd > search_state->best_rd) {
        search_state->skip_intra_modes = 1;
        return INT64_MAX;
      }
    }

    // Load the cached chroma decision into the outputs and mbmi.
    rd_stats_uv->rate = search_state->rate_uv_tokenonly[uv_tx];
    rd_stats_uv->dist = search_state->dist_uvs[uv_tx];
    rd_stats_uv->skip = search_state->skip_uvs[uv_tx];
    rd_stats->skip = rd_stats_y->skip && rd_stats_uv->skip;
    mbmi->uv_mode = search_state->mode_uv[uv_tx];
    if (try_palette) {
      pmi->palette_size[1] = search_state->pmi_uv[uv_tx].palette_size[1];
      memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
             search_state->pmi_uv[uv_tx].palette_colors + PALETTE_MAX_SIZE,
             2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
    }
    mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta[uv_tx];
  }
  rd_stats->rate = rd_stats_y->rate + mode_cost_y;
  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
    // super_block_yrd above includes the cost of the tx_size in the
    // tokenonly rate, but for intra blocks, tx_size is always coded
    // (prediction granularity), so we account for it in the full rate,
    // not the tokenonly rate.
    rd_stats_y->rate -= tx_size_cost(cm, x, bsize, mbmi->tx_size);
  }
  if (num_planes > 1 && !x->skip_chroma_rd) {
    const int uv_mode_cost =
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][mbmi->uv_mode];
    rd_stats->rate +=
        rd_stats_uv->rate +
        intra_mode_info_cost_uv(cpi, x, mbmi, bsize, uv_mode_cost);
  }
  if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
    rd_stats->rate += intra_cost_penalty;
  rd_stats->dist = rd_stats_y->dist + rd_stats_uv->dist;

  // Estimate the reference frame signaling cost and add it
  // to the rolling cost variable.
  rd_stats->rate += ref_frame_cost;
  if (rd_stats->skip) {
    // Back out the coefficient coding costs. rd_stats_y->rate has already had
    // the tx_size cost removed above, so that cost stays in rd_stats->rate.
    rd_stats->rate -= (rd_stats_y->rate + rd_stats_uv->rate);
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    // Cost the skip mb case
    rd_stats->rate += x->skip_cost[skip_ctx][1];
  } else {
    // Add in the cost of the no skip flag.
    rd_stats->rate += x->skip_cost[skip_ctx][0];
  }
  // Calculate the final RD estimate for this mode.
  const int64_t this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  // Keep record of best intra rd
  if (this_rd < search_state->best_intra_rd) {
    search_state->best_intra_rd = this_rd;
    search_state->best_intra_mode = mbmi->mode;
  }

  // With skip_intra_in_interframe, flag further intra modes for skipping when
  // this one is more than 1.5x worse than best_rd.
  if (sf->skip_intra_in_interframe) {
    if (search_state->best_rd < (INT64_MAX / 2) &&
        this_rd > (search_state->best_rd + (search_state->best_rd >> 1)))
      search_state->skip_intra_modes = 1;
  }

  if (!disable_skip) {
    for (int i = 0; i < REFERENCE_MODES; ++i)
      search_state->best_pred_rd[i] =
          AOMMIN(search_state->best_pred_rd[i], this_rd);
  }
  return this_rd;
}
12282 
collect_single_states(MACROBLOCK * x,InterModeSearchState * search_state,const MB_MODE_INFO * const mbmi)12283 static void collect_single_states(MACROBLOCK *x,
12284                                   InterModeSearchState *search_state,
12285                                   const MB_MODE_INFO *const mbmi) {
12286   int i, j;
12287   const MV_REFERENCE_FRAME ref_frame = mbmi->ref_frame[0];
12288   const PREDICTION_MODE this_mode = mbmi->mode;
12289   const int dir = ref_frame <= GOLDEN_FRAME ? 0 : 1;
12290   const int mode_offset = INTER_OFFSET(this_mode);
12291   const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
12292 
12293   // Simple rd
12294   int64_t simple_rd = search_state->simple_rd[this_mode][0][ref_frame];
12295   for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
12296     int64_t rd = search_state->simple_rd[this_mode][ref_mv_idx][ref_frame];
12297     if (rd < simple_rd) simple_rd = rd;
12298   }
12299 
12300   // Insertion sort of single_state
12301   SingleInterModeState this_state_s = { simple_rd, ref_frame, 1 };
12302   SingleInterModeState *state_s = search_state->single_state[dir][mode_offset];
12303   i = search_state->single_state_cnt[dir][mode_offset];
12304   for (j = i; j > 0 && state_s[j - 1].rd > this_state_s.rd; --j)
12305     state_s[j] = state_s[j - 1];
12306   state_s[j] = this_state_s;
12307   search_state->single_state_cnt[dir][mode_offset]++;
12308 
12309   // Modelled rd
12310   int64_t modelled_rd = search_state->modelled_rd[this_mode][0][ref_frame];
12311   for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
12312     int64_t rd = search_state->modelled_rd[this_mode][ref_mv_idx][ref_frame];
12313     if (rd < modelled_rd) modelled_rd = rd;
12314   }
12315 
12316   // Insertion sort of single_state_modelled
12317   SingleInterModeState this_state_m = { modelled_rd, ref_frame, 1 };
12318   SingleInterModeState *state_m =
12319       search_state->single_state_modelled[dir][mode_offset];
12320   i = search_state->single_state_modelled_cnt[dir][mode_offset];
12321   for (j = i; j > 0 && state_m[j - 1].rd > this_state_m.rd; --j)
12322     state_m[j] = state_m[j - 1];
12323   state_m[j] = this_state_m;
12324   search_state->single_state_modelled_cnt[dir][mode_offset]++;
12325 }
12326 
// Post-processes the per-(direction, single mode) reference lists gathered by
// collect_single_states() before compound modes are searched.
//
// 1) If sf.prune_comp_search_by_single_result >= 1, marks entries invalid
//    whose RD cost is well above the best NEWMV/GLOBALMV cost for that
//    direction: (rd >> 3) * prune_factor > best_rd, with prune_factor 6 at
//    level >= 2 and 5 otherwise (i.e. roughly rd > 1.33x or 1.6x best_rd).
//    Entry 0 of each list (its lowest-rd reference, since the lists are kept
//    sorted by ascending rd) is never invalidated.
// 2) Builds search_state->single_rd_order[dir][mode]: valid references in
//    simple-rd order, then references from the modelled-rd list that are not
//    yet present and were not invalidated in the simple-rd list, up to
//    max(simple count, modelled count) entries. Scanning a list stops at the
//    first INT64_MAX rd.
//
// NOTE(review): no NONE_FRAME terminator is written after the last entry;
// compound_skip_get_candidates() stops at NONE_FRAME, so this presumably
// relies on single_rd_order being pre-filled with NONE_FRAME — confirm.
static void analyze_single_states(const AV1_COMP *cpi,
                                  InterModeSearchState *search_state) {
  int i, j, dir, mode;
  if (cpi->sf.prune_comp_search_by_single_result >= 1) {
    for (dir = 0; dir < 2; ++dir) {
      int64_t best_rd;
      SingleInterModeState(*state)[FWD_REFS];
      const int prune_factor =
          cpi->sf.prune_comp_search_by_single_result >= 2 ? 6 : 5;

      // Use the best rd of GLOBALMV or NEWMV to prune the unlikely
      // reference frames for all the modes (NEARESTMV and NEARMV may not
      // have same motion vectors). Always keep the best of each mode
      // because it might form the best possible combination with other mode.
      state = search_state->single_state[dir];
      best_rd = AOMMIN(state[INTER_OFFSET(NEWMV)][0].rd,
                       state[INTER_OFFSET(GLOBALMV)][0].rd);
      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
        // Start at 1: entry 0 (the best of this mode) is always kept.
        for (i = 1; i < search_state->single_state_cnt[dir][mode]; ++i) {
          if (state[mode][i].rd != INT64_MAX &&
              (state[mode][i].rd >> 3) * prune_factor > best_rd) {
            state[mode][i].valid = 0;
          }
        }
      }

      // Same pruning, applied to the modelled-rd lists.
      state = search_state->single_state_modelled[dir];
      best_rd = AOMMIN(state[INTER_OFFSET(NEWMV)][0].rd,
                       state[INTER_OFFSET(GLOBALMV)][0].rd);
      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
        for (i = 1; i < search_state->single_state_modelled_cnt[dir][mode];
             ++i) {
          if (state[mode][i].rd != INT64_MAX &&
              (state[mode][i].rd >> 3) * prune_factor > best_rd) {
            state[mode][i].valid = 0;
          }
        }
      }
    }
  }

  // Ordering by simple rd first, then by modelled rd
  for (dir = 0; dir < 2; ++dir) {
    for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
      const int state_cnt_s = search_state->single_state_cnt[dir][mode];
      const int state_cnt_m =
          search_state->single_state_modelled_cnt[dir][mode];
      SingleInterModeState *state_s = search_state->single_state[dir][mode];
      SingleInterModeState *state_m =
          search_state->single_state_modelled[dir][mode];
      int count = 0;
      const int max_candidates = AOMMAX(state_cnt_s, state_cnt_m);
      // Valid simple-rd entries, in ascending rd order.
      for (i = 0; i < state_cnt_s; ++i) {
        if (state_s[i].rd == INT64_MAX) break;
        if (state_s[i].valid)
          search_state->single_rd_order[dir][mode][count++] =
              state_s[i].ref_frame;
      }
      // Top up from the modelled-rd list with references not already listed.
      if (count < max_candidates) {
        for (i = 0; i < state_cnt_m; ++i) {
          if (state_m[i].rd == INT64_MAX) break;
          if (state_m[i].valid) {
            int ref_frame = state_m[i].ref_frame;
            int match = 0;
            // Check if existing already
            for (j = 0; j < count; ++j) {
              if (search_state->single_rd_order[dir][mode][j] == ref_frame) {
                match = 1;
                break;
              }
            }
            if (!match) {
              // Check if this ref_frame is removed in simple rd
              int valid = 1;
              for (j = 0; j < state_cnt_s; j++) {
                if (ref_frame == state_s[j].ref_frame && !state_s[j].valid) {
                  valid = 0;
                  break;
                }
              }
              if (valid)
                search_state->single_rd_order[dir][mode][count++] = ref_frame;
            }
            if (count >= max_candidates) break;
          }
        }
      }
    }
  }
}
12417 
compound_skip_get_candidates(const AV1_COMP * cpi,const InterModeSearchState * search_state,const int dir,const PREDICTION_MODE mode)12418 static int compound_skip_get_candidates(
12419     const AV1_COMP *cpi, const InterModeSearchState *search_state,
12420     const int dir, const PREDICTION_MODE mode) {
12421   const int mode_offset = INTER_OFFSET(mode);
12422   const SingleInterModeState *state =
12423       search_state->single_state[dir][mode_offset];
12424   const SingleInterModeState *state_modelled =
12425       search_state->single_state_modelled[dir][mode_offset];
12426   int max_candidates = 0;
12427   int candidates;
12428 
12429   for (int i = 0; i < FWD_REFS; ++i) {
12430     if (search_state->single_rd_order[dir][mode_offset][i] == NONE_FRAME) break;
12431     max_candidates++;
12432   }
12433 
12434   candidates = max_candidates;
12435   if (cpi->sf.prune_comp_search_by_single_result >= 2) {
12436     candidates = AOMMIN(2, max_candidates);
12437   }
12438   if (cpi->sf.prune_comp_search_by_single_result >= 3) {
12439     if (state[0].rd != INT64_MAX && state_modelled[0].rd != INT64_MAX &&
12440         state[0].ref_frame == state_modelled[0].ref_frame)
12441       candidates = 1;
12442     if (mode == NEARMV || mode == GLOBALMV) candidates = 1;
12443   }
12444   return candidates;
12445 }
12446 
compound_skip_by_single_states(const AV1_COMP * cpi,const InterModeSearchState * search_state,const PREDICTION_MODE this_mode,const MV_REFERENCE_FRAME ref_frame,const MV_REFERENCE_FRAME second_ref_frame,const MACROBLOCK * x)12447 static int compound_skip_by_single_states(
12448     const AV1_COMP *cpi, const InterModeSearchState *search_state,
12449     const PREDICTION_MODE this_mode, const MV_REFERENCE_FRAME ref_frame,
12450     const MV_REFERENCE_FRAME second_ref_frame, const MACROBLOCK *x) {
12451   const MV_REFERENCE_FRAME refs[2] = { ref_frame, second_ref_frame };
12452   const int mode[2] = { compound_ref0_mode(this_mode),
12453                         compound_ref1_mode(this_mode) };
12454   const int mode_offset[2] = { INTER_OFFSET(mode[0]), INTER_OFFSET(mode[1]) };
12455   const int mode_dir[2] = { refs[0] <= GOLDEN_FRAME ? 0 : 1,
12456                             refs[1] <= GOLDEN_FRAME ? 0 : 1 };
12457   int ref_searched[2] = { 0, 0 };
12458   int ref_mv_match[2] = { 1, 1 };
12459   int i, j;
12460 
12461   for (i = 0; i < 2; ++i) {
12462     const SingleInterModeState *state =
12463         search_state->single_state[mode_dir[i]][mode_offset[i]];
12464     const int state_cnt =
12465         search_state->single_state_cnt[mode_dir[i]][mode_offset[i]];
12466     for (j = 0; j < state_cnt; ++j) {
12467       if (state[j].ref_frame == refs[i]) {
12468         ref_searched[i] = 1;
12469         break;
12470       }
12471     }
12472   }
12473 
12474   const int ref_set = get_drl_refmv_count(x, refs, this_mode);
12475   for (i = 0; i < 2; ++i) {
12476     if (mode[i] == NEARESTMV || mode[i] == NEARMV) {
12477       const MV_REFERENCE_FRAME single_refs[2] = { refs[i], NONE_FRAME };
12478       int idential = 1;
12479       for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ref_mv_idx++) {
12480         int_mv single_mv;
12481         int_mv comp_mv;
12482         get_this_mv(&single_mv, mode[i], 0, ref_mv_idx, single_refs,
12483                     x->mbmi_ext);
12484         get_this_mv(&comp_mv, this_mode, i, ref_mv_idx, refs, x->mbmi_ext);
12485 
12486         idential &= (single_mv.as_int == comp_mv.as_int);
12487         if (!idential) {
12488           ref_mv_match[i] = 0;
12489           break;
12490         }
12491       }
12492     }
12493   }
12494 
12495   for (i = 0; i < 2; ++i) {
12496     if (ref_searched[i] && ref_mv_match[i]) {
12497       const int candidates =
12498           compound_skip_get_candidates(cpi, search_state, mode_dir[i], mode[i]);
12499       const MV_REFERENCE_FRAME *ref_order =
12500           search_state->single_rd_order[mode_dir[i]][mode_offset[i]];
12501       int match = 0;
12502       for (j = 0; j < candidates; ++j) {
12503         if (refs[i] == ref_order[j]) {
12504           match = 1;
12505           break;
12506         }
12507       }
12508       if (!match) return 1;
12509     }
12510   }
12511 
12512   return 0;
12513 }
12514 
sf_check_is_drop_ref(const MODE_DEFINITION * mode,InterModeSearchState * search_state)12515 static INLINE int sf_check_is_drop_ref(const MODE_DEFINITION *mode,
12516                                        InterModeSearchState *search_state) {
12517   const MV_REFERENCE_FRAME ref_frame = mode->ref_frame[0];
12518   const MV_REFERENCE_FRAME second_ref_frame = mode->ref_frame[1];
12519   if (search_state->num_available_refs > 2) {
12520     if ((ref_frame == search_state->dist_order_refs[0] &&
12521          second_ref_frame == search_state->dist_order_refs[1]) ||
12522         (ref_frame == search_state->dist_order_refs[1] &&
12523          second_ref_frame == search_state->dist_order_refs[0]))
12524       return 1;  // drop this pair of refs
12525   }
12526   return 0;
12527 }
12528 
sf_drop_ref_analyze(InterModeSearchState * search_state,const MODE_DEFINITION * mode,int64_t distortion2)12529 static INLINE void sf_drop_ref_analyze(InterModeSearchState *search_state,
12530                                        const MODE_DEFINITION *mode,
12531                                        int64_t distortion2) {
12532   const PREDICTION_MODE this_mode = mode->mode;
12533   MV_REFERENCE_FRAME ref_frame = mode->ref_frame[0];
12534   const int idx = ref_frame - LAST_FRAME;
12535   if (idx && distortion2 > search_state->dist_refs[idx]) {
12536     search_state->dist_refs[idx] = distortion2;
12537     search_state->dist_order_refs[idx] = ref_frame;
12538   }
12539 
12540   // Reach the last single ref prediction mode
12541   if (ref_frame == ALTREF_FRAME && this_mode == GLOBALMV) {
12542     // bubble sort dist_refs and the order index
12543     for (int i = 0; i < REF_FRAMES; ++i) {
12544       for (int k = i + 1; k < REF_FRAMES; ++k) {
12545         if (search_state->dist_refs[i] < search_state->dist_refs[k]) {
12546           int64_t tmp_dist = search_state->dist_refs[i];
12547           search_state->dist_refs[i] = search_state->dist_refs[k];
12548           search_state->dist_refs[k] = tmp_dist;
12549 
12550           int tmp_idx = search_state->dist_order_refs[i];
12551           search_state->dist_order_refs[i] = search_state->dist_order_refs[k];
12552           search_state->dist_order_refs[k] = tmp_idx;
12553         }
12554       }
12555     }
12556     for (int i = 0; i < REF_FRAMES; ++i) {
12557       if (search_state->dist_refs[i] == -1) break;
12558       search_state->num_available_refs = i;
12559     }
12560     search_state->num_available_refs++;
12561   }
12562 }
12563 
12564 // sf->prune_single_motion_modes_by_simple_trans
// Speed feature sf->prune_single_motion_modes_by_simple_trans. After the first
// pass over the single-reference modes, this picks which reference frames to
// skip in the second pass. The choice uses the best SIMPLE_TRANSLATION RD cost
// recorded in x->simple_rd_state.
//
// Returns a bitmask with bit r set for each reference frame r to skip:
//   - if fewer than two references have a finite cost, nothing is skipped
//     (returns 0);
//   - otherwise references with no finite RD cost are skipped;
//   - if more than two have a finite cost, references costing more than
//     1.5x the overall best are skipped as well.
// cpi is currently unused.
static int analyze_simple_trans_states(const AV1_COMP *cpi, MACROBLOCK *x) {
  (void)cpi;
  // Best RD per reference frame. Filled in a loop so every entry starts at
  // INT64_MAX whatever REF_FRAMES is. A fixed-length initializer list would
  // leave any extra entries at 0, i.e. "valid with zero cost".
  int64_t rdcosts[REF_FRAMES];
  for (int r = 0; r < REF_FRAMES; ++r) rdcosts[r] = INT64_MAX;
  int skip_ref = 0;
  int64_t min_rd = INT64_MAX;
  for (int i = 0; i < SINGLE_REF_MODES; ++i) {
    const MODE_DEFINITION *mode_order = &av1_mode_order[i];
    const MV_REFERENCE_FRAME ref_frame = mode_order->ref_frame[0];
    for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
      const int64_t rd = x->simple_rd_state[i][k].rd_stats.rdcost;
      rdcosts[ref_frame] = AOMMIN(rdcosts[ref_frame], rd);
      min_rd = AOMMIN(min_rd, rd);
    }
  }
  int valid_cnt = 0;
  for (int i = 1; i < REF_FRAMES; ++i) {
    if (rdcosts[i] == INT64_MAX) {
      skip_ref |= (1 << i);
    } else {
      valid_cnt++;
    }
  }
  if (valid_cnt < 2) {
    return 0;
  }
  // Threshold is 1.5x the best cost. It saturates at INT64_MAX instead of
  // overflowing, which would be undefined for signed int64_t.
  const int64_t half = min_rd >> 1;
  const int64_t threshold =
      (min_rd > 0 && half > INT64_MAX - min_rd) ? INT64_MAX : min_rd + half;
  if (valid_cnt > 2) {
    for (int i = 1; i < REF_FRAMES; ++i) {
      if (rdcosts[i] > threshold) {
        skip_ref |= (1 << i);
      }
    }
  }
  return skip_ref;
}
12601 
// Allocates the scratch buffers in |bufs| used by the compound-type RD search.
// Failures are reported through CHECK_MEM_ERROR on |cm|. The buffers are
// freed by release_compound_type_rd_buffers().
static void alloc_compound_type_rd_buffers(AV1_COMMON *const cm,
                                           CompoundTypeRdBuffers *const bufs) {
  // Byte sizes. The 2x factor on the prediction and mask buffers is part of
  // the required sizing; do not shrink it without checking their users.
  const size_t pred0_bytes = 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0);
  const size_t pred1_bytes = 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1);
  const size_t residual1_bytes = MAX_SB_SQUARE * sizeof(*bufs->residual1);
  const size_t diff10_bytes = MAX_SB_SQUARE * sizeof(*bufs->diff10);
  const size_t mask_bytes = 2 * MAX_SB_SQUARE * sizeof(*bufs->tmp_best_mask_buf);

  // The prediction buffers are 16-byte aligned and the residual/diff
  // buffers 32-byte aligned; the mask buffer needs no special alignment.
  CHECK_MEM_ERROR(cm, bufs->pred0, (uint8_t *)aom_memalign(16, pred0_bytes));
  CHECK_MEM_ERROR(cm, bufs->pred1, (uint8_t *)aom_memalign(16, pred1_bytes));
  CHECK_MEM_ERROR(cm, bufs->residual1,
                  (int16_t *)aom_memalign(32, residual1_bytes));
  CHECK_MEM_ERROR(cm, bufs->diff10, (int16_t *)aom_memalign(32, diff10_bytes));
  CHECK_MEM_ERROR(cm, bufs->tmp_best_mask_buf,
                  (uint8_t *)aom_malloc(mask_bytes));
}
12620 
// Frees every buffer allocated by alloc_compound_type_rd_buffers(). Then it
// zeroes the whole struct so no freed pointer stays reachable through |bufs|.
static void release_compound_type_rd_buffers(
    CompoundTypeRdBuffers *const bufs) {
  // Release in reverse allocation order; the frees are independent.
  aom_free(bufs->tmp_best_mask_buf);
  aom_free(bufs->diff10);
  aom_free(bufs->residual1);
  aom_free(bufs->pred1);
  aom_free(bufs->pred0);
  av1_zero(*bufs);
}
12630 
12631 // Enables do_tx_search on a per-mode basis.
// Returns the tx-search setting to use for mode index |midx|.
// If adaptation is off, or the global setting is already non-zero, the global
// setting is returned unchanged. Otherwise TX search is switched on
// conditionally (value 2) for the first few modes only, and off (0) after.
int do_tx_search_mode(int do_tx_search_global, int midx, int adaptive) {
  enum { kConditionalTxSearchModeCount = 7 };
  if (adaptive && !do_tx_search_global) {
    return (midx < kConditionalTxSearchModeCount) ? 2 : 0;
  }
  return do_tx_search_global;
}
12640 
av1_rd_pick_inter_mode_sb(AV1_COMP * cpi,TileDataEnc * tile_data,MACROBLOCK * x,int mi_row,int mi_col,RD_STATS * rd_cost,BLOCK_SIZE bsize,PICK_MODE_CONTEXT * ctx,int64_t best_rd_so_far)12641 void av1_rd_pick_inter_mode_sb(AV1_COMP *cpi, TileDataEnc *tile_data,
12642                                MACROBLOCK *x, int mi_row, int mi_col,
12643                                RD_STATS *rd_cost, BLOCK_SIZE bsize,
12644                                PICK_MODE_CONTEXT *ctx, int64_t best_rd_so_far) {
12645   AV1_COMMON *const cm = &cpi->common;
12646   const int num_planes = av1_num_planes(cm);
12647   const SPEED_FEATURES *const sf = &cpi->sf;
12648   MACROBLOCKD *const xd = &x->e_mbd;
12649   MB_MODE_INFO *const mbmi = xd->mi[0];
12650   const int try_palette =
12651       cpi->oxcf.enable_palette &&
12652       av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
12653   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
12654   const struct segmentation *const seg = &cm->seg;
12655   PREDICTION_MODE this_mode;
12656   unsigned char segment_id = mbmi->segment_id;
12657   int i;
12658   struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE];
12659   unsigned int ref_costs_single[REF_FRAMES];
12660   unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES];
12661   int *comp_inter_cost = x->comp_inter_cost[av1_get_reference_mode_context(xd)];
12662   mode_skip_mask_t mode_skip_mask;
12663   uint8_t motion_mode_skip_mask = 0;  // second pass of single ref modes
12664 
12665   InterModeSearchState search_state;
12666   init_inter_mode_search_state(&search_state, cpi, tile_data, x, bsize,
12667                                best_rd_so_far);
12668   INTERINTRA_MODE interintra_modes[REF_FRAMES] = {
12669     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES,
12670     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES
12671   };
12672   HandleInterModeArgs args = {
12673     { NULL },  { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE },
12674     { NULL },  { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1 },
12675     NULL,      NULL,
12676     NULL,      search_state.modelled_rd,
12677     { { 0 } }, INT_MAX,
12678     INT_MAX,   search_state.simple_rd,
12679     0,         interintra_modes,
12680     1,         NULL
12681   };
12682   for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
12683 
12684   av1_invalid_rd_stats(rd_cost);
12685 
12686   // Ref frames that are selected by square partition blocks.
12687   int picked_ref_frames_mask = 0;
12688   if (cpi->sf.prune_ref_frame_for_rect_partitions &&
12689       mbmi->partition != PARTITION_NONE && mbmi->partition != PARTITION_SPLIT) {
12690     // prune_ref_frame_for_rect_partitions = 1 implies prune only extended
12691     // partition blocks. prune_ref_frame_for_rect_partitions >=2
12692     // implies prune for vert, horiz and extended partition blocks.
12693     if ((mbmi->partition != PARTITION_VERT &&
12694          mbmi->partition != PARTITION_HORZ) ||
12695         cpi->sf.prune_ref_frame_for_rect_partitions >= 2) {
12696       picked_ref_frames_mask = fetch_picked_ref_frames_mask(
12697           x, bsize, cm->seq_params.mib_size, mi_row, mi_col);
12698     }
12699   }
12700 
12701   // Skip ref frames that never selected by square blocks.
12702   const int skip_ref_frame_mask =
12703       picked_ref_frames_mask ? ~picked_ref_frames_mask : 0;
12704 
12705   // init params, set frame modes, speed features
12706   set_params_rd_pick_inter_mode(cpi, x, &args, bsize, mi_row, mi_col,
12707                                 &mode_skip_mask, skip_ref_frame_mask,
12708                                 ref_costs_single, ref_costs_comp, yv12_mb);
12709 
12710   int64_t best_est_rd = INT64_MAX;
12711   // TODO(angiebird): Turn this on when this speed feature is well tested
12712   const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
12713   // If do_tx_search_global is 0, only estimated RD should be computed.
12714   // If do_tx_search_global is 1, all modes have TX search performed.
12715   // If do_tx_search_global is 2, some modes will have TX search performed.
12716   const int do_tx_search_global =
12717       !((cpi->sf.inter_mode_rd_model_estimation == 1 && md->ready) ||
12718         (cpi->sf.inter_mode_rd_model_estimation == 2 &&
12719          x->source_variance < 512));
12720   InterModesInfo *inter_modes_info = x->inter_modes_info;
12721   inter_modes_info->num = 0;
12722 
12723   int intra_mode_num = 0;
12724   int intra_mode_idx_ls[MAX_MODES];
12725   int reach_first_comp_mode = 0;
12726 
12727   // Temporary buffers used by handle_inter_mode().
12728   uint8_t *const tmp_buf = get_buf_by_bd(xd, x->tmp_obmc_bufs[0]);
12729 
12730   CompoundTypeRdBuffers rd_buffers;
12731   alloc_compound_type_rd_buffers(cm, &rd_buffers);
12732 
12733   for (int midx = 0; midx < MAX_MODES; ++midx) {
12734     const int do_tx_search = do_tx_search_mode(
12735         do_tx_search_global, midx, sf->inter_mode_rd_model_estimation_adaptive);
12736     const MODE_DEFINITION *mode_order = &av1_mode_order[midx];
12737     this_mode = mode_order->mode;
12738     const MV_REFERENCE_FRAME ref_frame = mode_order->ref_frame[0];
12739     const MV_REFERENCE_FRAME second_ref_frame = mode_order->ref_frame[1];
12740     const int comp_pred = second_ref_frame > INTRA_FRAME;
12741 
12742     // When single ref motion search ends:
12743     // 1st pass: To evaluate single ref RD results and rewind to the beginning;
12744     // 2nd pass: To continue with compound ref search.
12745     if (sf->prune_single_motion_modes_by_simple_trans) {
12746       if (comp_pred && args.single_ref_first_pass) {
12747         args.single_ref_first_pass = 0;
12748         // Reach the first comp ref mode
12749         // Reset midx to start the 2nd pass for single ref motion search
12750         midx = -1;
12751         motion_mode_skip_mask = analyze_simple_trans_states(cpi, x);
12752         continue;
12753       }
12754       if (!comp_pred) {  // single ref mode
12755         if (args.single_ref_first_pass) {
12756           // clear stats
12757           for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
12758             x->simple_rd_state[midx][k].rd_stats.rdcost = INT64_MAX;
12759             x->simple_rd_state[midx][k].early_skipped = 0;
12760           }
12761         } else {
12762           if (motion_mode_skip_mask & (1 << ref_frame)) {
12763             continue;
12764           }
12765         }
12766       }
12767     }
12768 
12769     // Reach the first compound prediction mode
12770     if (sf->prune_comp_search_by_single_result > 0 && comp_pred &&
12771         reach_first_comp_mode == 0) {
12772       analyze_single_states(cpi, &search_state);
12773       reach_first_comp_mode = 1;
12774     }
12775     int64_t this_rd = INT64_MAX;
12776     int disable_skip = 0;
12777     int rate2 = 0, rate_y = 0, rate_uv = 0;
12778     int64_t distortion2 = 0;
12779     int skippable = 0;
12780     int this_skip2 = 0;
12781 
12782     init_mbmi(mbmi, midx, cm);
12783 
12784     x->skip = 0;
12785     set_ref_ptrs(cm, xd, ref_frame, second_ref_frame);
12786 
12787     if (inter_mode_compatible_skip(cpi, x, bsize, midx)) continue;
12788 
12789     const int ret = inter_mode_search_order_independent_skip(
12790         cpi, x, bsize, midx, mi_row, mi_col, &mode_skip_mask, &search_state,
12791         skip_ref_frame_mask);
12792     if (ret == 1) continue;
12793     args.skip_motion_mode = (ret == 2);
12794 
12795     if (sf->drop_ref && comp_pred) {
12796       if (sf_check_is_drop_ref(mode_order, &search_state)) {
12797         continue;
12798       }
12799     }
12800 
12801     if (search_state.best_rd < search_state.mode_threshold[midx]) continue;
12802 
12803     if (sf->prune_comp_search_by_single_result > 0 && comp_pred) {
12804       if (compound_skip_by_single_states(cpi, &search_state, this_mode,
12805                                          ref_frame, second_ref_frame, x))
12806         continue;
12807     }
12808 
12809     const int ref_frame_cost = comp_pred
12810                                    ? ref_costs_comp[ref_frame][second_ref_frame]
12811                                    : ref_costs_single[ref_frame];
12812     const int compmode_cost =
12813         is_comp_ref_allowed(mbmi->sb_type) ? comp_inter_cost[comp_pred] : 0;
12814     const int real_compmode_cost =
12815         cm->current_frame.reference_mode == REFERENCE_MODE_SELECT
12816             ? compmode_cost
12817             : 0;
12818 
12819     if (comp_pred) {
12820       if ((sf->mode_search_skip_flags & FLAG_SKIP_COMP_BESTINTRA) &&
12821           search_state.best_mode_index >= 0 &&
12822           search_state.best_mbmode.ref_frame[0] == INTRA_FRAME)
12823         continue;
12824     }
12825 
12826     if (ref_frame == INTRA_FRAME) {
12827       if ((!cpi->oxcf.enable_smooth_intra || sf->disable_smooth_intra) &&
12828           (mbmi->mode == SMOOTH_PRED || mbmi->mode == SMOOTH_H_PRED ||
12829            mbmi->mode == SMOOTH_V_PRED))
12830         continue;
12831       if (!cpi->oxcf.enable_paeth_intra && mbmi->mode == PAETH_PRED) continue;
12832       if (sf->adaptive_mode_search > 1)
12833         if ((x->source_variance << num_pels_log2_lookup[bsize]) >
12834             search_state.best_pred_sse)
12835           continue;
12836 
12837       if (this_mode != DC_PRED) {
12838         // Only search the oblique modes if the best so far is
12839         // one of the neighboring directional modes
12840         if ((sf->mode_search_skip_flags & FLAG_SKIP_INTRA_BESTINTER) &&
12841             (this_mode >= D45_PRED && this_mode <= PAETH_PRED)) {
12842           if (search_state.best_mode_index >= 0 &&
12843               search_state.best_mbmode.ref_frame[0] > INTRA_FRAME)
12844             continue;
12845         }
12846         if (sf->mode_search_skip_flags & FLAG_SKIP_INTRA_DIRMISMATCH) {
12847           if (conditional_skipintra(this_mode, search_state.best_intra_mode))
12848             continue;
12849         }
12850       }
12851     }
12852 
12853     // Select prediction reference frames.
12854     for (i = 0; i < num_planes; i++) {
12855       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
12856       if (comp_pred) xd->plane[i].pre[1] = yv12_mb[second_ref_frame][i];
12857     }
12858 
12859     if (ref_frame == INTRA_FRAME) {
12860       intra_mode_idx_ls[intra_mode_num++] = midx;
12861       continue;
12862     } else {
12863       mbmi->angle_delta[PLANE_TYPE_Y] = 0;
12864       mbmi->angle_delta[PLANE_TYPE_UV] = 0;
12865       mbmi->filter_intra_mode_info.use_filter_intra = 0;
12866       mbmi->ref_mv_idx = 0;
12867       int64_t ref_best_rd = search_state.best_rd;
12868       {
12869         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
12870         av1_init_rd_stats(&rd_stats);
12871         rd_stats.rate = rate2;
12872 
12873         // Point to variables that are maintained between loop iterations
12874         args.single_newmv = search_state.single_newmv;
12875         args.single_newmv_rate = search_state.single_newmv_rate;
12876         args.single_newmv_valid = search_state.single_newmv_valid;
12877         args.single_comp_cost = real_compmode_cost;
12878         args.ref_frame_cost = ref_frame_cost;
12879         if (midx < MAX_SINGLE_REF_MODES) {
12880           args.simple_rd_state = x->simple_rd_state[midx];
12881         }
12882 
12883 #if CONFIG_COLLECT_COMPONENT_TIMING
12884         start_timing(cpi, handle_inter_mode_time);
12885 #endif
12886         this_rd = handle_inter_mode(
12887             cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
12888             &disable_skip, mi_row, mi_col, &args, ref_best_rd, tmp_buf,
12889             &rd_buffers, &best_est_rd, do_tx_search, inter_modes_info);
12890 #if CONFIG_COLLECT_COMPONENT_TIMING
12891         end_timing(cpi, handle_inter_mode_time);
12892 #endif
12893         rate2 = rd_stats.rate;
12894         skippable = rd_stats.skip;
12895         distortion2 = rd_stats.dist;
12896         rate_y = rd_stats_y.rate;
12897         rate_uv = rd_stats_uv.rate;
12898       }
12899 
12900       if (sf->prune_comp_search_by_single_result > 0 &&
12901           is_inter_singleref_mode(this_mode) && args.single_ref_first_pass) {
12902         collect_single_states(x, &search_state, mbmi);
12903       }
12904 
12905       if (this_rd == INT64_MAX) continue;
12906 
12907       this_skip2 = mbmi->skip;
12908       this_rd = RDCOST(x->rdmult, rate2, distortion2);
12909       if (this_skip2) {
12910         rate_y = 0;
12911         rate_uv = 0;
12912       }
12913     }
12914 
12915     // Did this mode help.. i.e. is it the new best mode
12916     if (this_rd < search_state.best_rd || x->skip) {
12917       int mode_excluded = 0;
12918       if (comp_pred) {
12919         mode_excluded = cm->current_frame.reference_mode == SINGLE_REFERENCE;
12920       }
12921       if (!mode_excluded) {
12922         // Note index of best mode so far
12923         search_state.best_mode_index = midx;
12924 
12925         if (ref_frame == INTRA_FRAME) {
12926           /* required for left and above block mv */
12927           mbmi->mv[0].as_int = 0;
12928         } else {
12929           search_state.best_pred_sse = x->pred_sse[ref_frame];
12930         }
12931 
12932         rd_cost->rate = rate2;
12933         rd_cost->dist = distortion2;
12934         rd_cost->rdcost = this_rd;
12935         search_state.best_rd = this_rd;
12936         search_state.best_mbmode = *mbmi;
12937         search_state.best_skip2 = this_skip2;
12938         search_state.best_mode_skippable = skippable;
12939         if (do_tx_search) {
12940           // When do_tx_search == 0, handle_inter_mode won't provide correct
12941           // rate_y and rate_uv because txfm_search process is replaced by
12942           // rd estimation.
12943           // Therfore, we should avoid updating best_rate_y and best_rate_uv
12944           // here. These two values will be updated when txfm_search is called
12945           search_state.best_rate_y =
12946               rate_y +
12947               x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
12948           search_state.best_rate_uv = rate_uv;
12949         }
12950         memcpy(ctx->blk_skip, x->blk_skip,
12951                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
12952       }
12953     }
12954 
12955     /* keep record of best compound/single-only prediction */
12956     if (!disable_skip && ref_frame != INTRA_FRAME) {
12957       int64_t single_rd, hybrid_rd, single_rate, hybrid_rate;
12958 
12959       if (cm->current_frame.reference_mode == REFERENCE_MODE_SELECT) {
12960         single_rate = rate2 - compmode_cost;
12961         hybrid_rate = rate2;
12962       } else {
12963         single_rate = rate2;
12964         hybrid_rate = rate2 + compmode_cost;
12965       }
12966 
12967       single_rd = RDCOST(x->rdmult, single_rate, distortion2);
12968       hybrid_rd = RDCOST(x->rdmult, hybrid_rate, distortion2);
12969 
12970       if (!comp_pred) {
12971         if (single_rd < search_state.best_pred_rd[SINGLE_REFERENCE])
12972           search_state.best_pred_rd[SINGLE_REFERENCE] = single_rd;
12973       } else {
12974         if (single_rd < search_state.best_pred_rd[COMPOUND_REFERENCE])
12975           search_state.best_pred_rd[COMPOUND_REFERENCE] = single_rd;
12976       }
12977       if (hybrid_rd < search_state.best_pred_rd[REFERENCE_MODE_SELECT])
12978         search_state.best_pred_rd[REFERENCE_MODE_SELECT] = hybrid_rd;
12979     }
12980     if (sf->drop_ref && second_ref_frame == NONE_FRAME) {
12981       // Collect data from single ref mode, and analyze data.
12982       sf_drop_ref_analyze(&search_state, mode_order, distortion2);
12983     }
12984 
12985     if (x->skip && !comp_pred) break;
12986   }
12987 
12988   release_compound_type_rd_buffers(&rd_buffers);
12989 
12990 #if CONFIG_COLLECT_COMPONENT_TIMING
12991   start_timing(cpi, do_tx_search_time);
12992 #endif
12993   if (do_tx_search_global != 1) {
12994     inter_modes_info_sort(inter_modes_info, inter_modes_info->rd_idx_pair_arr);
12995     search_state.best_rd = INT64_MAX;
12996 
12997     int64_t top_est_rd =
12998         inter_modes_info->num > 0
12999             ? inter_modes_info
13000                   ->est_rd_arr[inter_modes_info->rd_idx_pair_arr[0].idx]
13001             : INT64_MAX;
13002     for (int j = 0; j < inter_modes_info->num; ++j) {
13003       const int data_idx = inter_modes_info->rd_idx_pair_arr[j].idx;
13004       *mbmi = inter_modes_info->mbmi_arr[data_idx];
13005       int64_t curr_est_rd = inter_modes_info->est_rd_arr[data_idx];
13006       if (curr_est_rd * 0.80 > top_est_rd) break;
13007 
13008       RD_STATS rd_stats;
13009       RD_STATS rd_stats_y;
13010       RD_STATS rd_stats_uv;
13011 
13012       bool true_rd = inter_modes_info->true_rd_arr[data_idx];
13013       if (true_rd) {
13014         rd_stats = inter_modes_info->rd_cost_arr[data_idx];
13015         rd_stats_y = inter_modes_info->rd_cost_y_arr[data_idx];
13016         rd_stats_uv = inter_modes_info->rd_cost_uv_arr[data_idx];
13017         memcpy(x->blk_skip, inter_modes_info->blk_skip_arr[data_idx],
13018                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13019       } else {
13020         const int mode_rate = inter_modes_info->mode_rate_arr[data_idx];
13021 
13022         x->skip = 0;
13023         set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
13024 
13025         // Select prediction reference frames.
13026         const int is_comp_pred = mbmi->ref_frame[1] > INTRA_FRAME;
13027         for (i = 0; i < num_planes; i++) {
13028           xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
13029           if (is_comp_pred)
13030             xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
13031         }
13032 
13033         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
13034                                       av1_num_planes(cm) - 1);
13035         if (mbmi->motion_mode == OBMC_CAUSAL)
13036           av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
13037 
13038         if (!txfm_search(cpi, tile_data, x, bsize, mi_row, mi_col, &rd_stats,
13039                          &rd_stats_y, &rd_stats_uv, mode_rate,
13040                          search_state.best_rd)) {
13041           continue;
13042         } else if (cpi->sf.inter_mode_rd_model_estimation == 1) {
13043           const int skip_ctx = av1_get_skip_context(xd);
13044           inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats.sse,
13045                                rd_stats.dist,
13046                                rd_stats_y.rate + rd_stats_uv.rate +
13047                                    x->skip_cost[skip_ctx][mbmi->skip]);
13048         }
13049         rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
13050       }
13051 
13052       if (rd_stats.rdcost < search_state.best_rd) {
13053         search_state.best_rd = rd_stats.rdcost;
13054         // Note index of best mode so far
13055         const int mode_index = get_prediction_mode_idx(
13056             mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
13057         search_state.best_mode_index = mode_index;
13058         *rd_cost = rd_stats;
13059         search_state.best_rd = rd_stats.rdcost;
13060         search_state.best_mbmode = *mbmi;
13061         search_state.best_skip2 = mbmi->skip;
13062         search_state.best_mode_skippable = rd_stats.skip;
13063         search_state.best_rate_y =
13064             rd_stats_y.rate +
13065             x->skip_cost[av1_get_skip_context(xd)][rd_stats.skip || mbmi->skip];
13066         search_state.best_rate_uv = rd_stats_uv.rate;
13067         memcpy(ctx->blk_skip, x->blk_skip,
13068                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13069       }
13070     }
13071   }
13072 #if CONFIG_COLLECT_COMPONENT_TIMING
13073   end_timing(cpi, do_tx_search_time);
13074 #endif
13075 
13076 #if CONFIG_COLLECT_COMPONENT_TIMING
13077   start_timing(cpi, handle_intra_mode_time);
13078 #endif
13079   for (int j = 0; j < intra_mode_num; ++j) {
13080     const int mode_index = intra_mode_idx_ls[j];
13081     const MV_REFERENCE_FRAME ref_frame =
13082         av1_mode_order[mode_index].ref_frame[0];
13083     assert(av1_mode_order[mode_index].ref_frame[1] == NONE_FRAME);
13084     assert(ref_frame == INTRA_FRAME);
13085     if (sf->skip_intra_in_interframe && search_state.skip_intra_modes) break;
13086     init_mbmi(mbmi, mode_index, cm);
13087     x->skip = 0;
13088     set_ref_ptrs(cm, xd, INTRA_FRAME, NONE_FRAME);
13089 
13090     // Select prediction reference frames.
13091     for (i = 0; i < num_planes; i++) {
13092       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
13093     }
13094 
13095     RD_STATS intra_rd_stats, intra_rd_stats_y, intra_rd_stats_uv;
13096 
13097     const int ref_frame_cost = ref_costs_single[ref_frame];
13098     intra_rd_stats.rdcost = handle_intra_mode(
13099         &search_state, cpi, x, bsize, mi_row, mi_col, ref_frame_cost, ctx, 0,
13100         &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
13101     if (intra_rd_stats.rdcost < search_state.best_rd) {
13102       search_state.best_rd = intra_rd_stats.rdcost;
13103       // Note index of best mode so far
13104       search_state.best_mode_index = mode_index;
13105       *rd_cost = intra_rd_stats;
13106       search_state.best_rd = intra_rd_stats.rdcost;
13107       search_state.best_mbmode = *mbmi;
13108       search_state.best_skip2 = 0;
13109       search_state.best_mode_skippable = intra_rd_stats.skip;
13110       search_state.best_rate_y =
13111           intra_rd_stats_y.rate +
13112           x->skip_cost[av1_get_skip_context(xd)][intra_rd_stats.skip];
13113       search_state.best_rate_uv = intra_rd_stats_uv.rate;
13114       memcpy(ctx->blk_skip, x->blk_skip,
13115              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13116     }
13117   }
13118 #if CONFIG_COLLECT_COMPONENT_TIMING
13119   end_timing(cpi, handle_intra_mode_time);
13120 #endif
13121 
13122   // In effect only when speed >= 2.
13123   sf_refine_fast_tx_type_search(
13124       cpi, x, mi_row, mi_col, rd_cost, bsize, ctx, search_state.best_mode_index,
13125       &search_state.best_mbmode, yv12_mb, search_state.best_rate_y,
13126       search_state.best_rate_uv, &search_state.best_skip2);
13127 
13128   // Only try palette mode when the best mode so far is an intra mode.
13129   if (try_palette && !is_inter_mode(search_state.best_mbmode.mode)) {
13130     search_palette_mode(cpi, x, mi_row, mi_col, rd_cost, ctx, bsize, mbmi, pmi,
13131                         ref_costs_single, &search_state);
13132   }
13133   search_state.best_mbmode.skip_mode = 0;
13134   if (cm->current_frame.skip_mode_info.skip_mode_flag &&
13135       !segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME) &&
13136       is_comp_ref_allowed(bsize)) {
13137     rd_pick_skip_mode(rd_cost, &search_state, cpi, x, bsize, mi_row, mi_col,
13138                       yv12_mb);
13139   }
13140 
13141   // Make sure that the ref_mv_idx is only nonzero when we're
13142   // using a mode which can support ref_mv_idx
13143   if (search_state.best_mbmode.ref_mv_idx != 0 &&
13144       !(search_state.best_mbmode.mode == NEWMV ||
13145         search_state.best_mbmode.mode == NEW_NEWMV ||
13146         have_nearmv_in_inter_mode(search_state.best_mbmode.mode))) {
13147     search_state.best_mbmode.ref_mv_idx = 0;
13148   }
13149 
13150   if (search_state.best_mode_index < 0 ||
13151       search_state.best_rd >= best_rd_so_far) {
13152     rd_cost->rate = INT_MAX;
13153     rd_cost->rdcost = INT64_MAX;
13154     return;
13155   }
13156 
13157   assert(
13158       (cm->interp_filter == SWITCHABLE) ||
13159       (cm->interp_filter ==
13160        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 0)) ||
13161       !is_inter_block(&search_state.best_mbmode));
13162   assert(
13163       (cm->interp_filter == SWITCHABLE) ||
13164       (cm->interp_filter ==
13165        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 1)) ||
13166       !is_inter_block(&search_state.best_mbmode));
13167 
13168   if (!cpi->rc.is_src_frame_alt_ref)
13169     av1_update_rd_thresh_fact(cm, tile_data->thresh_freq_fact,
13170                               sf->adaptive_rd_thresh, bsize,
13171                               search_state.best_mode_index);
13172 
13173   // macroblock modes
13174   *mbmi = search_state.best_mbmode;
13175   x->skip |= search_state.best_skip2;
13176 
13177   // Note: this section is needed since the mode may have been forced to
13178   // GLOBALMV by the all-zero mode handling of ref-mv.
13179   if (mbmi->mode == GLOBALMV || mbmi->mode == GLOBAL_GLOBALMV) {
13180     // Correct the interp filters for GLOBALMV
13181     if (is_nontrans_global_motion(xd, xd->mi[0])) {
13182       assert(mbmi->interp_filters ==
13183              av1_broadcast_interp_filter(
13184                  av1_unswitchable_filter(cm->interp_filter)));
13185     }
13186   }
13187 
13188   for (i = 0; i < REFERENCE_MODES; ++i) {
13189     if (search_state.best_pred_rd[i] == INT64_MAX)
13190       search_state.best_pred_diff[i] = INT_MIN;
13191     else
13192       search_state.best_pred_diff[i] =
13193           search_state.best_rd - search_state.best_pred_rd[i];
13194   }
13195 
13196   x->skip |= search_state.best_mode_skippable;
13197 
13198   assert(search_state.best_mode_index >= 0);
13199 
13200   store_coding_context(x, ctx, search_state.best_mode_index,
13201                        search_state.best_pred_diff,
13202                        search_state.best_mode_skippable);
13203 
13204   if (pmi->palette_size[1] > 0) {
13205     assert(try_palette);
13206     restore_uv_color_map(cpi, x);
13207   }
13208 }
13209 
// TODO(kyslov): This function is currently very similar to
//               av1_rd_pick_inter_mode_sb, except that it:
//                 - only checks non-compound (single reference) modes,
//                 - doesn't check palette mode, and
//                 - doesn't refine the transform (tx) search.
//               It is likely to be heavily modified once the non-RD mode
//               decision is implemented.
av1_nonrd_pick_inter_mode_sb(AV1_COMP * cpi,TileDataEnc * tile_data,MACROBLOCK * x,int mi_row,int mi_col,RD_STATS * rd_cost,BLOCK_SIZE bsize,PICK_MODE_CONTEXT * ctx,int64_t best_rd_so_far)13216 void av1_nonrd_pick_inter_mode_sb(AV1_COMP *cpi, TileDataEnc *tile_data,
13217                                   MACROBLOCK *x, int mi_row, int mi_col,
13218                                   RD_STATS *rd_cost, BLOCK_SIZE bsize,
13219                                   PICK_MODE_CONTEXT *ctx,
13220                                   int64_t best_rd_so_far) {
13221   AV1_COMMON *const cm = &cpi->common;
13222   const int num_planes = av1_num_planes(cm);
13223   const SPEED_FEATURES *const sf = &cpi->sf;
13224   MACROBLOCKD *const xd = &x->e_mbd;
13225   MB_MODE_INFO *const mbmi = xd->mi[0];
13226   const struct segmentation *const seg = &cm->seg;
13227   PREDICTION_MODE this_mode;
13228   unsigned char segment_id = mbmi->segment_id;
13229   int i;
13230   struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE];
13231   unsigned int ref_costs_single[REF_FRAMES];
13232   unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES];
13233   int *comp_inter_cost = x->comp_inter_cost[av1_get_reference_mode_context(xd)];
13234   mode_skip_mask_t mode_skip_mask;
13235   uint8_t motion_mode_skip_mask = 0;  // second pass of single ref modes
13236 
13237   InterModeSearchState search_state;
13238   init_inter_mode_search_state(&search_state, cpi, tile_data, x, bsize,
13239                                best_rd_so_far);
13240   INTERINTRA_MODE interintra_modes[REF_FRAMES] = {
13241     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES,
13242     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES
13243   };
13244   HandleInterModeArgs args = {
13245     { NULL },  { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE },
13246     { NULL },  { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1 },
13247     NULL,      NULL,
13248     NULL,      search_state.modelled_rd,
13249     { { 0 } }, INT_MAX,
13250     INT_MAX,   search_state.simple_rd,
13251     0,         interintra_modes,
13252     1,         NULL
13253   };
13254   for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
13255 
13256   av1_invalid_rd_stats(rd_cost);
13257 
13258   // Ref frames that are selected by square partition blocks.
13259   int picked_ref_frames_mask = 0;
13260   if (cpi->sf.prune_ref_frame_for_rect_partitions &&
13261       mbmi->partition != PARTITION_NONE && mbmi->partition != PARTITION_SPLIT) {
13262     // Don't enable for vert and horz partition blocks if current frame
13263     // will be used as bwd or arf2.
13264     if ((!cpi->refresh_bwd_ref_frame && !cpi->refresh_alt2_ref_frame) ||
13265         (mbmi->partition != PARTITION_VERT &&
13266          mbmi->partition != PARTITION_HORZ)) {
13267       picked_ref_frames_mask = fetch_picked_ref_frames_mask(
13268           x, bsize, cm->seq_params.mib_size, mi_row, mi_col);
13269     }
13270   }
13271 
13272   // Skip ref frames that never selected by square blocks.
13273   const int skip_ref_frame_mask =
13274       picked_ref_frames_mask ? ~picked_ref_frames_mask : 0;
13275 
13276   // init params, set frame modes, speed features
13277   set_params_nonrd_pick_inter_mode(cpi, x, &args, bsize, mi_row, mi_col,
13278                                    &mode_skip_mask, skip_ref_frame_mask,
13279                                    ref_costs_single, ref_costs_comp, yv12_mb);
13280 
13281   int64_t best_est_rd = INT64_MAX;
13282   InterModesInfo *inter_modes_info = x->inter_modes_info;
13283   inter_modes_info->num = 0;
13284 
13285   int intra_mode_num = 0;
13286   int intra_mode_idx_ls[MAX_MODES];
13287   int reach_first_comp_mode = 0;
13288 
13289   // Temporary buffers used by handle_inter_mode().
13290   uint8_t *const tmp_buf = get_buf_by_bd(xd, x->tmp_obmc_bufs[0]);
13291 
13292   CompoundTypeRdBuffers rd_buffers;
13293   alloc_compound_type_rd_buffers(cm, &rd_buffers);
13294 
13295   for (int midx = 0; midx < MAX_MODES; ++midx) {
13296     const MODE_DEFINITION *mode_order = &av1_mode_order[midx];
13297     this_mode = mode_order->mode;
13298     const MV_REFERENCE_FRAME ref_frame = mode_order->ref_frame[0];
13299     const MV_REFERENCE_FRAME second_ref_frame = mode_order->ref_frame[1];
13300     const int comp_pred = second_ref_frame > INTRA_FRAME;
13301 
13302     if (second_ref_frame != NONE_FRAME) continue;
13303 
13304     // When single ref motion search ends:
13305     // 1st pass: To evaluate single ref RD results and rewind to the beginning;
13306     // 2nd pass: To continue with compound ref search.
13307     if (sf->prune_single_motion_modes_by_simple_trans) {
13308       if (comp_pred && args.single_ref_first_pass) {
13309         args.single_ref_first_pass = 0;
13310         // Reach the first comp ref mode
13311         // Reset midx to start the 2nd pass for single ref motion search
13312         midx = -1;
13313         motion_mode_skip_mask = analyze_simple_trans_states(cpi, x);
13314         continue;
13315       }
13316       if (!comp_pred && ref_frame != INTRA_FRAME) {  // single ref mode
13317         if (args.single_ref_first_pass) {
13318           // clear stats
13319           for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
13320             x->simple_rd_state[midx][k].rd_stats.rdcost = INT64_MAX;
13321             x->simple_rd_state[midx][k].early_skipped = 0;
13322           }
13323         } else {
13324           if (motion_mode_skip_mask & (1 << ref_frame)) {
13325             continue;
13326           }
13327         }
13328       }
13329     }
13330 
13331     // Reach the first compound prediction mode
13332     if (sf->prune_comp_search_by_single_result > 0 && comp_pred &&
13333         reach_first_comp_mode == 0) {
13334       analyze_single_states(cpi, &search_state);
13335       reach_first_comp_mode = 1;
13336     }
13337     int64_t this_rd = INT64_MAX;
13338     int disable_skip = 0;
13339     int rate2 = 0;
13340     int64_t distortion2 = 0;
13341     int skippable = 0;
13342     int this_skip2 = 0;
13343 
13344     init_mbmi(mbmi, midx, cm);
13345 
13346     x->skip = 0;
13347     set_ref_ptrs(cm, xd, ref_frame, second_ref_frame);
13348 
13349     if (inter_mode_compatible_skip(cpi, x, bsize, midx)) continue;
13350 
13351     const int ret = inter_mode_search_order_independent_skip(
13352         cpi, x, bsize, midx, mi_row, mi_col, &mode_skip_mask, &search_state,
13353         skip_ref_frame_mask);
13354     if (ret == 1) continue;
13355     args.skip_motion_mode = (ret == 2);
13356 
13357     if (sf->drop_ref && comp_pred) {
13358       if (sf_check_is_drop_ref(mode_order, &search_state)) {
13359         continue;
13360       }
13361     }
13362 
13363     if (search_state.best_rd < search_state.mode_threshold[midx]) continue;
13364 
13365     if (sf->prune_comp_search_by_single_result > 0 && comp_pred) {
13366       if (compound_skip_by_single_states(cpi, &search_state, this_mode,
13367                                          ref_frame, second_ref_frame, x))
13368         continue;
13369     }
13370 
13371     const int ref_frame_cost = comp_pred
13372                                    ? ref_costs_comp[ref_frame][second_ref_frame]
13373                                    : ref_costs_single[ref_frame];
13374     const int compmode_cost =
13375         is_comp_ref_allowed(mbmi->sb_type) ? comp_inter_cost[comp_pred] : 0;
13376     const int real_compmode_cost =
13377         cm->current_frame.reference_mode == REFERENCE_MODE_SELECT
13378             ? compmode_cost
13379             : 0;
13380 
13381     if (comp_pred) {
13382       if ((sf->mode_search_skip_flags & FLAG_SKIP_COMP_BESTINTRA) &&
13383           search_state.best_mode_index >= 0 &&
13384           search_state.best_mbmode.ref_frame[0] == INTRA_FRAME)
13385         continue;
13386     }
13387 
13388     if (ref_frame == INTRA_FRAME) {
13389       if (!cpi->oxcf.enable_smooth_intra &&
13390           (mbmi->mode == SMOOTH_PRED || mbmi->mode == SMOOTH_H_PRED ||
13391            mbmi->mode == SMOOTH_V_PRED))
13392         continue;
13393       if (!cpi->oxcf.enable_paeth_intra && mbmi->mode == PAETH_PRED) continue;
13394       if (sf->adaptive_mode_search > 1)
13395         if ((x->source_variance << num_pels_log2_lookup[bsize]) >
13396             search_state.best_pred_sse)
13397           continue;
13398 
13399       if (this_mode != DC_PRED) {
13400         // Only search the oblique modes if the best so far is
13401         // one of the neighboring directional modes
13402         if ((sf->mode_search_skip_flags & FLAG_SKIP_INTRA_BESTINTER) &&
13403             (this_mode >= D45_PRED && this_mode <= PAETH_PRED)) {
13404           if (search_state.best_mode_index >= 0 &&
13405               search_state.best_mbmode.ref_frame[0] > INTRA_FRAME)
13406             continue;
13407         }
13408         if (sf->mode_search_skip_flags & FLAG_SKIP_INTRA_DIRMISMATCH) {
13409           if (conditional_skipintra(this_mode, search_state.best_intra_mode))
13410             continue;
13411         }
13412       }
13413     }
13414 
13415     // Select prediction reference frames.
13416     for (i = 0; i < num_planes; i++) {
13417       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
13418       if (comp_pred) xd->plane[i].pre[1] = yv12_mb[second_ref_frame][i];
13419     }
13420 
13421     if (ref_frame == INTRA_FRAME) {
13422       intra_mode_idx_ls[intra_mode_num++] = midx;
13423       continue;
13424     } else {
13425       mbmi->angle_delta[PLANE_TYPE_Y] = 0;
13426       mbmi->angle_delta[PLANE_TYPE_UV] = 0;
13427       mbmi->filter_intra_mode_info.use_filter_intra = 0;
13428       mbmi->ref_mv_idx = 0;
13429       int64_t ref_best_rd = search_state.best_rd;
13430       {
13431         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
13432         av1_init_rd_stats(&rd_stats);
13433         rd_stats.rate = rate2;
13434 
13435         // Point to variables that are maintained between loop iterations
13436         args.single_newmv = search_state.single_newmv;
13437         args.single_newmv_rate = search_state.single_newmv_rate;
13438         args.single_newmv_valid = search_state.single_newmv_valid;
13439         args.single_comp_cost = real_compmode_cost;
13440         args.ref_frame_cost = ref_frame_cost;
13441         if (midx < MAX_SINGLE_REF_MODES) {
13442           args.simple_rd_state = x->simple_rd_state[midx];
13443         }
13444         this_rd = handle_inter_mode(
13445             cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
13446             &disable_skip, mi_row, mi_col, &args, ref_best_rd, tmp_buf,
13447             &rd_buffers, &best_est_rd, 0, inter_modes_info);
13448         rate2 = rd_stats.rate;
13449         skippable = rd_stats.skip;
13450         distortion2 = rd_stats.dist;
13451       }
13452 
13453       if (sf->prune_comp_search_by_single_result > 0 &&
13454           is_inter_singleref_mode(this_mode) && args.single_ref_first_pass) {
13455         collect_single_states(x, &search_state, mbmi);
13456       }
13457 
13458       if (this_rd == INT64_MAX) continue;
13459 
13460       this_skip2 = mbmi->skip;
13461       this_rd = RDCOST(x->rdmult, rate2, distortion2);
13462     }
13463 
13464     // Did this mode help.. i.e. is it the new best mode
13465     if (this_rd < search_state.best_rd || x->skip) {
13466       int mode_excluded = 0;
13467       if (comp_pred) {
13468         mode_excluded = cm->current_frame.reference_mode == SINGLE_REFERENCE;
13469       }
13470       if (!mode_excluded) {
13471         // Note index of best mode so far
13472         search_state.best_mode_index = midx;
13473 
13474         if (ref_frame == INTRA_FRAME) {
13475           /* required for left and above block mv */
13476           mbmi->mv[0].as_int = 0;
13477         } else {
13478           search_state.best_pred_sse = x->pred_sse[ref_frame];
13479         }
13480 
13481         rd_cost->rate = rate2;
13482         rd_cost->dist = distortion2;
13483         rd_cost->rdcost = this_rd;
13484         search_state.best_rd = this_rd;
13485         search_state.best_mbmode = *mbmi;
13486         search_state.best_skip2 = this_skip2;
13487         search_state.best_mode_skippable = skippable;
13488         memcpy(ctx->blk_skip, x->blk_skip,
13489                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13490       }
13491     }
13492 
13493     /* keep record of best compound/single-only prediction */
13494     if (!disable_skip && ref_frame != INTRA_FRAME) {
13495       int64_t single_rd, hybrid_rd, single_rate, hybrid_rate;
13496 
13497       if (cm->current_frame.reference_mode == REFERENCE_MODE_SELECT) {
13498         single_rate = rate2 - compmode_cost;
13499         hybrid_rate = rate2;
13500       } else {
13501         single_rate = rate2;
13502         hybrid_rate = rate2 + compmode_cost;
13503       }
13504 
13505       single_rd = RDCOST(x->rdmult, single_rate, distortion2);
13506       hybrid_rd = RDCOST(x->rdmult, hybrid_rate, distortion2);
13507 
13508       if (!comp_pred) {
13509         if (single_rd < search_state.best_pred_rd[SINGLE_REFERENCE])
13510           search_state.best_pred_rd[SINGLE_REFERENCE] = single_rd;
13511       } else {
13512         if (single_rd < search_state.best_pred_rd[COMPOUND_REFERENCE])
13513           search_state.best_pred_rd[COMPOUND_REFERENCE] = single_rd;
13514       }
13515       if (hybrid_rd < search_state.best_pred_rd[REFERENCE_MODE_SELECT])
13516         search_state.best_pred_rd[REFERENCE_MODE_SELECT] = hybrid_rd;
13517     }
13518     if (sf->drop_ref && second_ref_frame == NONE_FRAME) {
13519       // Collect data from single ref mode, and analyze data.
13520       sf_drop_ref_analyze(&search_state, mode_order, distortion2);
13521     }
13522 
13523     if (x->skip && !comp_pred) break;
13524   }
13525 
13526   release_compound_type_rd_buffers(&rd_buffers);
13527 
13528   inter_modes_info_sort(inter_modes_info, inter_modes_info->rd_idx_pair_arr);
13529   search_state.best_rd = INT64_MAX;
13530 
13531   if (inter_modes_info->num > 0) {
13532     const int data_idx = inter_modes_info->rd_idx_pair_arr[0].idx;
13533     *mbmi = inter_modes_info->mbmi_arr[data_idx];
13534     const int mode_rate = inter_modes_info->mode_rate_arr[data_idx];
13535 
13536     x->skip = 0;
13537     set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
13538 
13539     // Select prediction reference frames.
13540     const int is_comp_pred = mbmi->ref_frame[1] > INTRA_FRAME;
13541     for (i = 0; i < num_planes; i++) {
13542       xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
13543       if (is_comp_pred) xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
13544     }
13545 
13546     RD_STATS rd_stats;
13547     RD_STATS rd_stats_y;
13548     RD_STATS rd_stats_uv;
13549 
13550     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
13551                                   av1_num_planes(cm) - 1);
13552     if (mbmi->motion_mode == OBMC_CAUSAL)
13553       av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
13554 
13555     if (txfm_search(cpi, tile_data, x, bsize, mi_row, mi_col, &rd_stats,
13556                     &rd_stats_y, &rd_stats_uv, mode_rate,
13557                     search_state.best_rd)) {
13558       if (cpi->sf.inter_mode_rd_model_estimation == 1) {
13559         const int skip_ctx = av1_get_skip_context(xd);
13560         inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats.sse,
13561                              rd_stats.dist,
13562                              rd_stats_y.rate + rd_stats_uv.rate +
13563                                  x->skip_cost[skip_ctx][mbmi->skip]);
13564       }
13565       rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
13566 
13567       if (rd_stats.rdcost < search_state.best_rd) {
13568         search_state.best_rd = rd_stats.rdcost;
13569         // Note index of best mode so far
13570         const int mode_index = get_prediction_mode_idx(
13571             mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
13572         search_state.best_mode_index = mode_index;
13573         *rd_cost = rd_stats;
13574         search_state.best_rd = rd_stats.rdcost;
13575         search_state.best_mbmode = *mbmi;
13576         search_state.best_skip2 = mbmi->skip;
13577         search_state.best_mode_skippable = rd_stats.skip;
13578         search_state.best_rate_y =
13579             rd_stats_y.rate +
13580             x->skip_cost[av1_get_skip_context(xd)][rd_stats.skip || mbmi->skip];
13581         search_state.best_rate_uv = rd_stats_uv.rate;
13582         memcpy(ctx->blk_skip, x->blk_skip,
13583                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13584       }
13585     }
13586   }
13587 
13588   for (int j = 0; j < intra_mode_num; ++j) {
13589     const int mode_index = intra_mode_idx_ls[j];
13590     const MV_REFERENCE_FRAME ref_frame =
13591         av1_mode_order[mode_index].ref_frame[0];
13592     assert(av1_mode_order[mode_index].ref_frame[1] == NONE_FRAME);
13593     assert(ref_frame == INTRA_FRAME);
13594     if (sf->skip_intra_in_interframe && search_state.skip_intra_modes) break;
13595     init_mbmi(mbmi, mode_index, cm);
13596     x->skip = 0;
13597     set_ref_ptrs(cm, xd, INTRA_FRAME, NONE_FRAME);
13598 
13599     // Select prediction reference frames.
13600     for (i = 0; i < num_planes; i++) {
13601       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
13602     }
13603 
13604     RD_STATS intra_rd_stats, intra_rd_stats_y, intra_rd_stats_uv;
13605 
13606     const int ref_frame_cost = ref_costs_single[ref_frame];
13607     intra_rd_stats.rdcost = handle_intra_mode(
13608         &search_state, cpi, x, bsize, mi_row, mi_col, ref_frame_cost, ctx, 0,
13609         &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
13610     if (intra_rd_stats.rdcost < search_state.best_rd) {
13611       search_state.best_rd = intra_rd_stats.rdcost;
13612       // Note index of best mode so far
13613       search_state.best_mode_index = mode_index;
13614       *rd_cost = intra_rd_stats;
13615       search_state.best_rd = intra_rd_stats.rdcost;
13616       search_state.best_mbmode = *mbmi;
13617       search_state.best_skip2 = 0;
13618       search_state.best_mode_skippable = intra_rd_stats.skip;
13619       search_state.best_rate_y =
13620           intra_rd_stats_y.rate +
13621           x->skip_cost[av1_get_skip_context(xd)][intra_rd_stats.skip];
13622       search_state.best_rate_uv = intra_rd_stats_uv.rate;
13623       memcpy(ctx->blk_skip, x->blk_skip,
13624              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
13625     }
13626   }
13627 
13628   search_state.best_mbmode.skip_mode = 0;
13629   if (cm->current_frame.skip_mode_info.skip_mode_flag &&
13630       !segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME) &&
13631       is_comp_ref_allowed(bsize)) {
13632     rd_pick_skip_mode(rd_cost, &search_state, cpi, x, bsize, mi_row, mi_col,
13633                       yv12_mb);
13634   }
13635 
13636   // Make sure that the ref_mv_idx is only nonzero when we're
13637   // using a mode which can support ref_mv_idx
13638   if (search_state.best_mbmode.ref_mv_idx != 0 &&
13639       !(search_state.best_mbmode.mode == NEWMV ||
13640         search_state.best_mbmode.mode == NEW_NEWMV ||
13641         have_nearmv_in_inter_mode(search_state.best_mbmode.mode))) {
13642     search_state.best_mbmode.ref_mv_idx = 0;
13643   }
13644 
13645   if (search_state.best_mode_index < 0 ||
13646       search_state.best_rd >= best_rd_so_far) {
13647     rd_cost->rate = INT_MAX;
13648     rd_cost->rdcost = INT64_MAX;
13649     return;
13650   }
13651 
13652   assert(
13653       (cm->interp_filter == SWITCHABLE) ||
13654       (cm->interp_filter ==
13655        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 0)) ||
13656       !is_inter_block(&search_state.best_mbmode));
13657   assert(
13658       (cm->interp_filter == SWITCHABLE) ||
13659       (cm->interp_filter ==
13660        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 1)) ||
13661       !is_inter_block(&search_state.best_mbmode));
13662 
13663   if (!cpi->rc.is_src_frame_alt_ref)
13664     av1_update_rd_thresh_fact(cm, tile_data->thresh_freq_fact,
13665                               sf->adaptive_rd_thresh, bsize,
13666                               search_state.best_mode_index);
13667 
13668   // macroblock modes
13669   *mbmi = search_state.best_mbmode;
13670   x->skip |= search_state.best_skip2;
13671 
13672   // Note: this section is needed since the mode may have been forced to
13673   // GLOBALMV by the all-zero mode handling of ref-mv.
13674   if (mbmi->mode == GLOBALMV || mbmi->mode == GLOBAL_GLOBALMV) {
13675     // Correct the interp filters for GLOBALMV
13676     if (is_nontrans_global_motion(xd, xd->mi[0])) {
13677       assert(mbmi->interp_filters ==
13678              av1_broadcast_interp_filter(
13679                  av1_unswitchable_filter(cm->interp_filter)));
13680     }
13681   }
13682 
13683   for (i = 0; i < REFERENCE_MODES; ++i) {
13684     if (search_state.best_pred_rd[i] == INT64_MAX)
13685       search_state.best_pred_diff[i] = INT_MIN;
13686     else
13687       search_state.best_pred_diff[i] =
13688           search_state.best_rd - search_state.best_pred_rd[i];
13689   }
13690 
13691   x->skip |= search_state.best_mode_skippable;
13692 
13693   assert(search_state.best_mode_index >= 0);
13694 
13695   store_coding_context(x, ctx, search_state.best_mode_index,
13696                        search_state.best_pred_diff,
13697                        search_state.best_mode_skippable);
13698 }
13699 
/* Picks the coding mode for a block whose segment has SEG_LVL_SKIP active.
 *
 * No search is performed. The block is set up as single-reference GLOBALMV
 * with SIMPLE_TRANSLATION motion, tx_size = max_txsize_lookup[bsize] and
 * x->skip = 1. The reference is the segment's SEG_LVL_REF_FRAME value when
 * that feature is active, LAST_FRAME otherwise. The interpolation filter is
 * the frame-level filter or, on a SWITCHABLE frame, the filter with the lowest
 * signaling rate when av1_is_interp_needed()/av1_is_interp_search_needed()
 * hold and x->source_variance reaches sf.disable_filter_search_var_thresh
 * (EIGHTTAP_REGULAR otherwise).
 *
 * Distortion is taken as 0, so the RD cost consists only of the rate of the
 * interpolation filter, the single/compound flag (REFERENCE_MODE_SELECT frames
 * only) and the reference-frame signaling.
 *
 * Outputs:
 *   rd_cost - rate/dist/rdcost of that choice. If rdcost >= best_rd_so_far,
 *             rate is set to INT_MAX and rdcost to INT64_MAX and the function
 *             returns without updating the RD threshold factors or ctx.
 *   ctx     - otherwise filled by store_coding_context() with THR_GLOBALMV.
 * Side effects: x->pred_sse[] and x->pred_mv_sad[] are reset to INT_MAX and
 * xd->mi[0] is overwritten with the chosen mode.
 */
void av1_rd_pick_inter_mode_sb_seg_skip(const AV1_COMP *cpi,
                                        TileDataEnc *tile_data, MACROBLOCK *x,
                                        int mi_row, int mi_col,
                                        RD_STATS *rd_cost, BLOCK_SIZE bsize,
                                        PICK_MODE_CONTEXT *ctx,
                                        int64_t best_rd_so_far) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  unsigned char segment_id = mbmi->segment_id;
  const int comp_pred = 0;
  int i;
  int64_t best_pred_diff[REFERENCE_MODES];
  unsigned int ref_costs_single[REF_FRAMES];
  unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES];
  int *comp_inter_cost = x->comp_inter_cost[av1_get_reference_mode_context(xd)];
  InterpFilter best_filter = SWITCHABLE;
  int64_t this_rd = INT64_MAX;
  int rate2 = 0;
  // Distortion is not measured for the forced-skip mode; it is fixed at 0.
  const int64_t distortion2 = 0;

  av1_collect_neighbors_ref_counts(xd);

  estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
                           ref_costs_comp);

  for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
  for (i = LAST_FRAME; i < REF_FRAMES; ++i) x->pred_mv_sad[i] = INT_MAX;

  rd_cost->rate = INT_MAX;

  assert(segfeature_active(&cm->seg, segment_id, SEG_LVL_SKIP));

  // Fixed mode: single-reference GLOBALMV, no palette and no filter-intra.
  mbmi->palette_mode_info.palette_size[0] = 0;
  mbmi->palette_mode_info.palette_size[1] = 0;
  mbmi->filter_intra_mode_info.use_filter_intra = 0;
  mbmi->mode = GLOBALMV;
  mbmi->motion_mode = SIMPLE_TRANSLATION;
  mbmi->uv_mode = UV_DC_PRED;
  if (segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME))
    mbmi->ref_frame[0] = get_segdata(&cm->seg, segment_id, SEG_LVL_REF_FRAME);
  else
    mbmi->ref_frame[0] = LAST_FRAME;
  mbmi->ref_frame[1] = NONE_FRAME;
  mbmi->mv[0].as_int =
      gm_get_motion_vector(&cm->global_motion[mbmi->ref_frame[0]],
                           cm->allow_high_precision_mv, bsize, mi_col, mi_row,
                           cm->cur_frame_force_integer_mv)
          .as_int;
  mbmi->tx_size = max_txsize_lookup[bsize];
  x->skip = 1;

  mbmi->ref_mv_idx = 0;

  av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
  if (is_motion_variation_allowed_bsize(bsize) && !has_second_ref(mbmi)) {
    int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
    mbmi->num_proj_ref = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
    // Select the samples according to motion vector difference
    if (mbmi->num_proj_ref > 1)
      mbmi->num_proj_ref = selectSamples(&mbmi->mv[0].as_mv, pts, pts_inref,
                                         mbmi->num_proj_ref, bsize);
  }

  set_default_interp_filters(mbmi, cm->interp_filter);

  if (cm->interp_filter != SWITCHABLE) {
    best_filter = cm->interp_filter;
  } else {
    best_filter = EIGHTTAP_REGULAR;
    if (av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd) &&
        x->source_variance >= cpi->sf.disable_filter_search_var_thresh) {
      // Distortion is fixed, so the best filter is the cheapest to signal.
      int rs;
      int best_rs = INT_MAX;
      for (i = 0; i < SWITCHABLE_FILTERS; ++i) {
        mbmi->interp_filters = av1_broadcast_interp_filter(i);
        rs = av1_get_switchable_rate(cm, x, xd);
        if (rs < best_rs) {
          best_rs = rs;
          best_filter = av1_extract_interp_filter(mbmi->interp_filters, 0);
        }
      }
    }
  }
  // Set the appropriate filter
  mbmi->interp_filters = av1_broadcast_interp_filter(best_filter);
  rate2 += av1_get_switchable_rate(cm, x, xd);

  if (cm->current_frame.reference_mode == REFERENCE_MODE_SELECT)
    rate2 += comp_inter_cost[comp_pred];

  // Estimate the reference frame signaling cost and add it to the rolling
  // cost variable. Charge the reference actually selected above rather than
  // always LAST_FRAME, which was wrong whenever the segment forces another
  // reference. NOTE(review): estimate_ref_frame_costs() is expected to report
  // zero cost for every frame when SEG_LVL_REF_FRAME is active — confirm.
  rate2 += ref_costs_single[mbmi->ref_frame[0]];
  this_rd = RDCOST(x->rdmult, rate2, distortion2);

  rd_cost->rate = rate2;
  rd_cost->dist = distortion2;
  rd_cost->rdcost = this_rd;

  if (this_rd >= best_rd_so_far) {
    rd_cost->rate = INT_MAX;
    rd_cost->rdcost = INT64_MAX;
    return;
  }

  assert((cm->interp_filter == SWITCHABLE) ||
         (cm->interp_filter ==
          av1_extract_interp_filter(mbmi->interp_filters, 0)));

  av1_update_rd_thresh_fact(cm, tile_data->thresh_freq_fact,
                            cpi->sf.adaptive_rd_thresh, bsize, THR_GLOBALMV);

  av1_zero(best_pred_diff);

  store_coding_context(x, ctx, THR_GLOBALMV, best_pred_diff, 0);
}
13820 
13821 struct calc_target_weighted_pred_ctxt {
13822   const MACROBLOCK *x;
13823   const uint8_t *tmp;
13824   int tmp_stride;
13825   int overlap;
13826 };
13827 
calc_target_weighted_pred_above(MACROBLOCKD * xd,int rel_mi_col,uint8_t nb_mi_width,MB_MODE_INFO * nb_mi,void * fun_ctxt,const int num_planes)13828 static INLINE void calc_target_weighted_pred_above(
13829     MACROBLOCKD *xd, int rel_mi_col, uint8_t nb_mi_width, MB_MODE_INFO *nb_mi,
13830     void *fun_ctxt, const int num_planes) {
13831   (void)nb_mi;
13832   (void)num_planes;
13833 
13834   struct calc_target_weighted_pred_ctxt *ctxt =
13835       (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
13836 
13837   const int bw = xd->n4_w << MI_SIZE_LOG2;
13838   const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
13839 
13840   int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_col * MI_SIZE);
13841   int32_t *mask = ctxt->x->mask_buf + (rel_mi_col * MI_SIZE);
13842   const uint8_t *tmp = ctxt->tmp + rel_mi_col * MI_SIZE;
13843   const int is_hbd = is_cur_buf_hbd(xd);
13844 
13845   if (!is_hbd) {
13846     for (int row = 0; row < ctxt->overlap; ++row) {
13847       const uint8_t m0 = mask1d[row];
13848       const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
13849       for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
13850         wsrc[col] = m1 * tmp[col];
13851         mask[col] = m0;
13852       }
13853       wsrc += bw;
13854       mask += bw;
13855       tmp += ctxt->tmp_stride;
13856     }
13857   } else {
13858     const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
13859 
13860     for (int row = 0; row < ctxt->overlap; ++row) {
13861       const uint8_t m0 = mask1d[row];
13862       const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
13863       for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
13864         wsrc[col] = m1 * tmp16[col];
13865         mask[col] = m0;
13866       }
13867       wsrc += bw;
13868       mask += bw;
13869       tmp16 += ctxt->tmp_stride;
13870     }
13871   }
13872 }
13873 
calc_target_weighted_pred_left(MACROBLOCKD * xd,int rel_mi_row,uint8_t nb_mi_height,MB_MODE_INFO * nb_mi,void * fun_ctxt,const int num_planes)13874 static INLINE void calc_target_weighted_pred_left(
13875     MACROBLOCKD *xd, int rel_mi_row, uint8_t nb_mi_height, MB_MODE_INFO *nb_mi,
13876     void *fun_ctxt, const int num_planes) {
13877   (void)nb_mi;
13878   (void)num_planes;
13879 
13880   struct calc_target_weighted_pred_ctxt *ctxt =
13881       (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
13882 
13883   const int bw = xd->n4_w << MI_SIZE_LOG2;
13884   const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
13885 
13886   int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_row * MI_SIZE * bw);
13887   int32_t *mask = ctxt->x->mask_buf + (rel_mi_row * MI_SIZE * bw);
13888   const uint8_t *tmp = ctxt->tmp + (rel_mi_row * MI_SIZE * ctxt->tmp_stride);
13889   const int is_hbd = is_cur_buf_hbd(xd);
13890 
13891   if (!is_hbd) {
13892     for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
13893       for (int col = 0; col < ctxt->overlap; ++col) {
13894         const uint8_t m0 = mask1d[col];
13895         const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
13896         wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
13897                     (tmp[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
13898         mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
13899       }
13900       wsrc += bw;
13901       mask += bw;
13902       tmp += ctxt->tmp_stride;
13903     }
13904   } else {
13905     const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
13906 
13907     for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
13908       for (int col = 0; col < ctxt->overlap; ++col) {
13909         const uint8_t m0 = mask1d[col];
13910         const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
13911         wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
13912                     (tmp16[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
13913         mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
13914       }
13915       wsrc += bw;
13916       mask += bw;
13917       tmp16 += ctxt->tmp_stride;
13918     }
13919   }
13920 }
13921 
13922 // This function has a structure similar to av1_build_obmc_inter_prediction
13923 //
13924 // The OBMC predictor is computed as:
13925 //
13926 //  PObmc(x,y) =
13927 //    AOM_BLEND_A64(Mh(x),
13928 //                  AOM_BLEND_A64(Mv(y), P(x,y), PAbove(x,y)),
13929 //                  PLeft(x, y))
13930 //
13931 // Scaling up by AOM_BLEND_A64_MAX_ALPHA ** 2 and omitting the intermediate
13932 // rounding, this can be written as:
13933 //
13934 //  AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA * Pobmc(x,y) =
13935 //    Mh(x) * Mv(y) * P(x,y) +
13936 //      Mh(x) * Cv(y) * Pabove(x,y) +
13937 //      AOM_BLEND_A64_MAX_ALPHA * Ch(x) * PLeft(x, y)
13938 //
13939 // Where :
13940 //
13941 //  Cv(y) = AOM_BLEND_A64_MAX_ALPHA - Mv(y)
13942 //  Ch(y) = AOM_BLEND_A64_MAX_ALPHA - Mh(y)
13943 //
13944 // This function computes 'wsrc' and 'mask' as:
13945 //
13946 //  wsrc(x, y) =
13947 //    AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA * src(x, y) -
13948 //      Mh(x) * Cv(y) * Pabove(x,y) +
13949 //      AOM_BLEND_A64_MAX_ALPHA * Ch(x) * PLeft(x, y)
13950 //
13951 //  mask(x, y) = Mh(x) * Mv(y)
13952 //
13953 // These can then be used to efficiently approximate the error for any
13954 // predictor P in the context of the provided neighbouring predictors by
13955 // computing:
13956 //
13957 //  error(x, y) =
13958 //    wsrc(x, y) - mask(x, y) * P(x, y) / (AOM_BLEND_A64_MAX_ALPHA ** 2)
13959 //
calc_target_weighted_pred(const AV1_COMMON * cm,const MACROBLOCK * x,const MACROBLOCKD * xd,int mi_row,int mi_col,const uint8_t * above,int above_stride,const uint8_t * left,int left_stride)13960 static void calc_target_weighted_pred(const AV1_COMMON *cm, const MACROBLOCK *x,
13961                                       const MACROBLOCKD *xd, int mi_row,
13962                                       int mi_col, const uint8_t *above,
13963                                       int above_stride, const uint8_t *left,
13964                                       int left_stride) {
13965   const BLOCK_SIZE bsize = xd->mi[0]->sb_type;
13966   const int bw = xd->n4_w << MI_SIZE_LOG2;
13967   const int bh = xd->n4_h << MI_SIZE_LOG2;
13968   int32_t *mask_buf = x->mask_buf;
13969   int32_t *wsrc_buf = x->wsrc_buf;
13970 
13971   const int is_hbd = is_cur_buf_hbd(xd);
13972   const int src_scale = AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA;
13973 
13974   // plane 0 should not be subsampled
13975   assert(xd->plane[0].subsampling_x == 0);
13976   assert(xd->plane[0].subsampling_y == 0);
13977 
13978   av1_zero_array(wsrc_buf, bw * bh);
13979   for (int i = 0; i < bw * bh; ++i) mask_buf[i] = AOM_BLEND_A64_MAX_ALPHA;
13980 
13981   // handle above row
13982   if (xd->up_available) {
13983     const int overlap =
13984         AOMMIN(block_size_high[bsize], block_size_high[BLOCK_64X64]) >> 1;
13985     struct calc_target_weighted_pred_ctxt ctxt = { x, above, above_stride,
13986                                                    overlap };
13987     foreach_overlappable_nb_above(cm, (MACROBLOCKD *)xd, mi_col,
13988                                   max_neighbor_obmc[mi_size_wide_log2[bsize]],
13989                                   calc_target_weighted_pred_above, &ctxt);
13990   }
13991 
13992   for (int i = 0; i < bw * bh; ++i) {
13993     wsrc_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
13994     mask_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
13995   }
13996 
13997   // handle left column
13998   if (xd->left_available) {
13999     const int overlap =
14000         AOMMIN(block_size_wide[bsize], block_size_wide[BLOCK_64X64]) >> 1;
14001     struct calc_target_weighted_pred_ctxt ctxt = { x, left, left_stride,
14002                                                    overlap };
14003     foreach_overlappable_nb_left(cm, (MACROBLOCKD *)xd, mi_row,
14004                                  max_neighbor_obmc[mi_size_high_log2[bsize]],
14005                                  calc_target_weighted_pred_left, &ctxt);
14006   }
14007 
14008   if (!is_hbd) {
14009     const uint8_t *src = x->plane[0].src.buf;
14010 
14011     for (int row = 0; row < bh; ++row) {
14012       for (int col = 0; col < bw; ++col) {
14013         wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
14014       }
14015       wsrc_buf += bw;
14016       src += x->plane[0].src.stride;
14017     }
14018   } else {
14019     const uint16_t *src = CONVERT_TO_SHORTPTR(x->plane[0].src.buf);
14020 
14021     for (int row = 0; row < bh; ++row) {
14022       for (int col = 0; col < bw; ++col) {
14023         wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
14024       }
14025       wsrc_buf += bw;
14026       src += x->plane[0].src.stride;
14027     }
14028   }
14029 }
14030 
/* 3x3 Sobel gradients at column i, row j of a 2-D sample array `src` whose
 * row pitch is `stride`. These are macros so the same expression works on
 * both uint8_t (low bit depth) and uint16_t (high bit depth) arrays.
 *
 * Sign convention: SOBEL_X is (left column) - (right column) and SOBEL_Y is
 * (top row) - (bottom row), with each column/row weighted 1, 2, 1. This is
 * the negation of the common right-minus-left / bottom-minus-top form.
 * The caller must keep (i, j) at least one sample away from every border.
 * Arguments are evaluated more than once. */
#define SOBEL_X(src, stride, i, j)                \
  (((src)[((i)-1) + (stride) * ((j)-1)] +         \
    2 * (src)[((i)-1) + (stride) * (j)] +         \
    (src)[((i)-1) + (stride) * ((j) + 1)]) -      \
   ((src)[((i) + 1) + (stride) * ((j)-1)] +       \
    2 * (src)[((i) + 1) + (stride) * (j)] +       \
    (src)[((i) + 1) + (stride) * ((j) + 1)]))
#define SOBEL_Y(src, stride, i, j)                \
  (((src)[((i)-1) + (stride) * ((j)-1)] +         \
    2 * (src)[(i) + (stride) * ((j)-1)] +         \
    (src)[((i) + 1) + (stride) * ((j)-1)]) -      \
   ((src)[((i)-1) + (stride) * ((j) + 1)] +       \
    2 * (src)[(i) + (stride) * ((j) + 1)] +       \
    (src)[((i) + 1) + (stride) * ((j) + 1)]))
14047 
sobel(const uint8_t * input,int stride,int i,int j,bool high_bd)14048 sobel_xy sobel(const uint8_t *input, int stride, int i, int j, bool high_bd) {
14049   int16_t s_x;
14050   int16_t s_y;
14051   if (high_bd) {
14052     const uint16_t *src = CONVERT_TO_SHORTPTR(input);
14053     s_x = SOBEL_X(src, stride, i, j);
14054     s_y = SOBEL_Y(src, stride, i, j);
14055   } else {
14056     s_x = SOBEL_X(input, stride, i, j);
14057     s_y = SOBEL_Y(input, stride, i, j);
14058   }
14059   sobel_xy r = { .x = s_x, .y = s_y };
14060   return r;
14061 }
14062 
14063 // 8-tap Gaussian convolution filter with sigma = 1.3, sums to 128,
14064 // all co-efficients must be even.
14065 DECLARE_ALIGNED(16, static const int16_t, gauss_filter[8]) = { 2,  12, 30, 40,
14066                                                                30, 12, 2,  0 };
14067 
gaussian_blur(const uint8_t * src,int src_stride,int w,int h,uint8_t * dst,bool high_bd,int bd)14068 void gaussian_blur(const uint8_t *src, int src_stride, int w, int h,
14069                    uint8_t *dst, bool high_bd, int bd) {
14070   ConvolveParams conv_params = get_conv_params(0, 0, bd);
14071   InterpFilterParams filter = { .filter_ptr = gauss_filter,
14072                                 .taps = 8,
14073                                 .subpel_shifts = 0,
14074                                 .interp_filter = EIGHTTAP_REGULAR };
14075   // Requirements from the vector-optimized implementations.
14076   assert(h % 4 == 0);
14077   assert(w % 8 == 0);
14078   // Because we use an eight tap filter, the stride should be at least 7 + w.
14079   assert(src_stride >= w + 7);
14080   if (high_bd) {
14081     av1_highbd_convolve_2d_sr(CONVERT_TO_SHORTPTR(src), src_stride,
14082                               CONVERT_TO_SHORTPTR(dst), w, w, h, &filter,
14083                               &filter, 0, 0, &conv_params, bd);
14084   } else {
14085     av1_convolve_2d_sr(src, src_stride, dst, w, w, h, &filter, &filter, 0, 0,
14086                        &conv_params);
14087   }
14088 }
14089 
edge_probability(const uint8_t * input,int w,int h,bool high_bd,int bd)14090 static EdgeInfo edge_probability(const uint8_t *input, int w, int h,
14091                                  bool high_bd, int bd) {
14092   // The probability of an edge in the whole image is the same as the highest
14093   // probability of an edge for any individual pixel. Use Sobel as the metric
14094   // for finding an edge.
14095   uint16_t highest = 0;
14096   uint16_t highest_x = 0;
14097   uint16_t highest_y = 0;
14098   // Ignore the 1 pixel border around the image for the computation.
14099   for (int j = 1; j < h - 1; ++j) {
14100     for (int i = 1; i < w - 1; ++i) {
14101       sobel_xy g = sobel(input, w, i, j, high_bd);
14102       // Scale down to 8-bit to get same output regardless of bit depth.
14103       int16_t g_x = g.x >> (bd - 8);
14104       int16_t g_y = g.y >> (bd - 8);
14105       uint16_t magnitude = (uint16_t)sqrt(g_x * g_x + g_y * g_y);
14106       highest = AOMMAX(highest, magnitude);
14107       highest_x = AOMMAX(highest_x, g_x);
14108       highest_y = AOMMAX(highest_y, g_y);
14109     }
14110   }
14111   EdgeInfo ei = { .magnitude = highest, .x = highest_x, .y = highest_y };
14112   return ei;
14113 }
14114 
14115 /* Uses most of the Canny edge detection algorithm to find if there are any
14116  * edges in the image.
14117  */
av1_edge_exists(const uint8_t * src,int src_stride,int w,int h,bool high_bd,int bd)14118 EdgeInfo av1_edge_exists(const uint8_t *src, int src_stride, int w, int h,
14119                          bool high_bd, int bd) {
14120   if (w < 3 || h < 3) {
14121     EdgeInfo n = { .magnitude = 0, .x = 0, .y = 0 };
14122     return n;
14123   }
14124   uint8_t *blurred;
14125   if (high_bd) {
14126     blurred = CONVERT_TO_BYTEPTR(aom_memalign(32, sizeof(uint16_t) * w * h));
14127   } else {
14128     blurred = (uint8_t *)aom_memalign(32, sizeof(uint8_t) * w * h);
14129   }
14130   gaussian_blur(src, src_stride, w, h, blurred, high_bd, bd);
14131   // Skip the non-maximum suppression step in Canny edge detection. We just
14132   // want a probability of an edge existing in the buffer, which is determined
14133   // by the strongest edge in it -- we don't need to eliminate the weaker
14134   // edges. Use Sobel for the edge detection.
14135   EdgeInfo prob = edge_probability(blurred, w, h, high_bd, bd);
14136   if (high_bd) {
14137     aom_free(CONVERT_TO_SHORTPTR(blurred));
14138   } else {
14139     aom_free(blurred);
14140   }
14141   return prob;
14142 }
14143