1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/common/cfl.h"
13 #include "av1/common/reconintra.h"
14 #include "av1/encoder/encodetxb.h"
15 #include "av1/encoder/hybrid_fwd_txfm.h"
16 #include "av1/common/idct.h"
17 #include "av1/encoder/model_rd.h"
18 #include "av1/encoder/random.h"
19 #include "av1/encoder/rdopt_utils.h"
20 #include "av1/encoder/tx_prune_model_weights.h"
21 #include "av1/encoder/tx_search.h"
22 
// Arguments bundled together for per-transform-block RD cost evaluation.
// NOTE(review): the users of this struct are outside the visible chunk; the
// per-field descriptions below are inferred from names and types — confirm
// against the callers before relying on them.
struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  // Local copies of the above/left entropy contexts, one entry per mi unit
  // (sized MAX_MIB_SIZE).
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  // Accumulated rate/distortion statistics (presumably summed over blocks).
  RD_STATS rd_stats;
  // Presumably the RD cost accumulated so far and the budget it must beat.
  int64_t current_rd;
  int64_t best_rd;
  // Presumably flags recording that evaluation stopped early.
  int exit_early;
  int incomplete_exit;
  int use_fast_coef_costing;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};
37 
// Result record for one evaluated transform candidate.
// NOTE(review): not used in the visible chunk; field meanings are inferred
// from their names — confirm against the users.
typedef struct {
  int64_t rd;           // RD cost of the candidate (presumably).
  int txb_entropy_ctx;  // Transform-block entropy context it produced.
  TX_TYPE tx_type;      // Transform type that was evaluated.
} TxCandidateInfo;
43 
// Static, index-based description of one node of a TX-size search tree (see
// the rd_record_tree_* tables). init_rd_record_tree() converts the indices
// into TXB_RD_INFO_NODE child pointers.
typedef struct {
  // Non-zero for a node without children; its children[] is then ignored and
  // the corresponding TXB_RD_INFO_NODE gets all-NULL children.
  int leaf;
  // Indices of the children in the same flat node array. A value <= 0 means
  // "no child" (index 0 is always the root, so it can never be a child).
  int8_t children[4];
} RD_RECORD_IDX_NODE;
48 
// Coefficient thresholds used by predict_skip_flag(), indexed as
// [bit-depth index][BLOCK_SIZE]. The bit-depth index is 0 for 8-bit, 1 for
// 10-bit and 2 for any other bit depth.
//
// Each value is origin_threshold * 128 / 100, i.e. an original percentage
// threshold converted to Q7. predict_skip_flag() shifts coefficient
// magnitudes left by 7 before comparing them with threshold * quantizer, so a
// coefficient counts as significant when
//   |coef| >= (origin_threshold / 100) * quantizer.
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
64 
// Transform size used by predict_skip_flag() for each block size. This is a
// precomputed form of:
//   int max_tx_size = max_txsize_rect_lookup[bsize];
//   if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//     max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
// That is the largest rectangular transform for the block, capped so neither
// side exceeds 16. The cap also keeps the coefficients within the 32 * 32
// scratch buffer used by predict_skip_flag().
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};
75 
// Returns the slot of |cur_record| that holds RD info for |hash|.
//
// Stored hashes are scanned newest-first. On a miss, a slot is claimed: a
// fresh one while the circular buffer still has room, otherwise the oldest
// one, with the start of the ring advanced. The claimed slot is stamped with
// |hash| and its RD info is zeroed before its index is returned.
static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
                                const uint32_t hash) {
  const int start = cur_record->index_start;

  // Newest-to-oldest: entries just before |start| first, then wrap around to
  // the tail of the buffer and walk back down to |start|.
  int pos = start;
  while (pos-- > 0) {
    if (cur_record->hash_vals[pos] == hash) return pos;
  }
  pos = cur_record->num;
  while (pos-- > start) {
    if (cur_record->hash_vals[pos] == hash) return pos;
  }

  // Miss: grow while there is room, otherwise recycle the oldest slot.
  int slot;
  if (cur_record->num >= TX_SIZE_RD_RECORD_BUFFER_LEN) {
    slot = start;
    cur_record->index_start = (start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
  } else {
    slot = (start + cur_record->num) % TX_SIZE_RD_RECORD_BUFFER_LEN;
    cur_record->num += 1;
  }

  cur_record->hash_vals[slot] = hash;
  av1_zero(cur_record->tx_rd_info[slot]);
  return slot;
}
101 
// TX-size search trees.
//
// find_tx_size_rd_records() fills a flat TXB_RD_INFO_NODE array one level at
// a time (largest transform first) and, within a level, in raster order over
// the block. Each table below describes the first rd_record_tree_size[bsize]
// entries of that array. An entry holds a leaf flag and, for inner nodes, the
// array indices of up to four children; an index <= 0 means "no child".
// init_rd_record_tree() turns the indices into child pointers. Nodes past the
// end of a table (the deepest level) are not described here.

// BLOCK_8X8: a single node with no children.
static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
  { 1, { 0 } },
};

// BLOCK_8X16: rectangular root (0) split into two 8x8 leaves, top (1) and
// bottom (2).
static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0, 0, 0, 0 } },
  { 1, { 0, 0, 0, 0 } },
};

// BLOCK_16X8: rectangular root (0) split into two 8x8 leaves, left (1) and
// right (2).
static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0 } },
  { 1, { 0 } },
};

// BLOCK_16X16: root (0) split into four 8x8 leaves (1-4, raster order).
static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
  { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
};

// 1:2 blocks (BLOCK_16X32, BLOCK_32X64): rectangular root (0) split into a top
// (1) and a bottom (2) square. Each square has four quarter-size children:
// 3-6 and 7-10, two per row over four rows.
static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 5, 6 } },
  { 0, { 7, 8, 9, 10 } },
};

// 2:1 blocks (BLOCK_32X16, BLOCK_64X32): rectangular root (0) split into a
// left (1) and a right (2) square. The quarter-size squares (3-10) are
// numbered four per row over the whole block, so each half takes two from
// each of the two rows.
static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 7, 8 } },
  { 0, { 5, 6, 9, 10 } },
};

// Square blocks (BLOCK_32X32, BLOCK_64X64): root (0), four quarter squares
// (1-4, two per row), then sixteen leaves (5-20, four per row).
static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
  { 0, { 1, 2, 3, 4 } },     { 0, { 5, 6, 9, 10 } },    { 0, { 7, 8, 11, 12 } },
  { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
};
138 
// BLOCK_64X128: two roots (0 top, 1 bottom), eight second-level nodes (2-9,
// two per row) and 32 leaves (10-41, four per row). Each inner node's children
// are the 2x2 group of next-level nodes it covers.
static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
  { 0, { 2, 3, 4, 5 } },     { 0, { 6, 7, 8, 9 } },
  { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
  { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
  { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
  { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
};

// BLOCK_128X64: two roots (0 left, 1 right), eight second-level nodes (2-9,
// four per row) and 32 leaves (10-41, eight per row).
static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
  { 0, { 2, 3, 6, 7 } },     { 0, { 4, 5, 8, 9 } },
  { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
  { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
  { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
  { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
};

// BLOCK_128X128: four roots (0-3, two per row), sixteen second-level nodes
// (4-19, four per row) and 64 leaves (20-83, eight per row).
static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
  { 0, { 4, 5, 8, 9 } },     { 0, { 6, 7, 10, 11 } },
  { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
  { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
  { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
  { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
  { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
  { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
  { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
  { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
  { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
};
167 
168 static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
169   { 0, { 1, -1, 2, -1 } },
170   { 0, { 3, 4, -1, -1 } },
171   { 0, { 5, 6, -1, -1 } },
172 };
173 
// 4:1 (wide) blocks, BLOCK_32X8 and BLOCK_64X16: rectangular root (0), its
// left (1) and right (2) halves, and four square leaves (3-6) from left to
// right.
static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};
179 
// TX-size search tree description for each block size, consumed by
// init_rd_record_tree(). The entries are NULL for block sizes with a 4-pixel
// side. find_tx_size_rd_records() is expected to reject those (it skips
// blocks whose largest square transform is below 8x8) before any tree is
// built. rd_record_tree_size[] is 0 for them, so no node is read either way.
static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
  NULL,                    // BLOCK_4X4
  NULL,                    // BLOCK_4X8
  NULL,                    // BLOCK_8X4
  rd_record_tree_8x8,      // BLOCK_8X8
  rd_record_tree_8x16,     // BLOCK_8X16
  rd_record_tree_16x8,     // BLOCK_16X8
  rd_record_tree_16x16,    // BLOCK_16X16
  rd_record_tree_1_2,      // BLOCK_16X32
  rd_record_tree_2_1,      // BLOCK_32X16
  rd_record_tree_sqr,      // BLOCK_32X32
  rd_record_tree_1_2,      // BLOCK_32X64
  rd_record_tree_2_1,      // BLOCK_64X32
  rd_record_tree_sqr,      // BLOCK_64X64
  rd_record_tree_64x128,   // BLOCK_64X128
  rd_record_tree_128x64,   // BLOCK_128X64
  rd_record_tree_128x128,  // BLOCK_128X128
  NULL,                    // BLOCK_4X16
  NULL,                    // BLOCK_16X4
  rd_record_tree_1_4,      // BLOCK_8X32
  rd_record_tree_4_1,      // BLOCK_32X8
  rd_record_tree_1_4,      // BLOCK_16X64
  rd_record_tree_4_1,      // BLOCK_64X16
};
204 
// Number of described nodes in each rd_record_tree[] entry (0 for the NULL
// entries). The sizes are derived with sizeof, so they stay in sync with the
// table definitions above.
static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
  0,                                                            // BLOCK_4X4
  0,                                                            // BLOCK_4X8
  0,                                                            // BLOCK_8X4
  sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X8
  sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_8X16
  sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_16X8
  sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE),    // BLOCK_16X16
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X32
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X16
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X32
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X64
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X32
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X64
  sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_64X128
  sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_128X64
  sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE),  // BLOCK_128X128
  0,                                                            // BLOCK_4X16
  0,                                                            // BLOCK_16X4
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X32
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X8
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X64
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X16
};
229 
init_rd_record_tree(TXB_RD_INFO_NODE * tree,BLOCK_SIZE bsize)230 static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
231                                        BLOCK_SIZE bsize) {
232   const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
233   const int size = rd_record_tree_size[bsize];
234   for (int i = 0; i < size; ++i) {
235     if (rd_record[i].leaf) {
236       av1_zero(tree[i].children);
237     } else {
238       for (int j = 0; j < 4; ++j) {
239         const int8_t idx = rd_record[i].children[j];
240         tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
241       }
242     }
243   }
244 }
245 
246 // Go through all TX blocks that could be used in TX size search, compute
247 // residual hash values for them and find matching RD info that stores previous
248 // RD search results for these TX blocks. The idea is to prevent repeated
249 // rate/distortion computations that happen because of the combination of
250 // partition and TX size search. The resulting RD info records are returned in
251 // the form of a quadtree for easier access in actual TX size search.
find_tx_size_rd_records(MACROBLOCK * x,BLOCK_SIZE bsize,TXB_RD_INFO_NODE * dst_rd_info)252 static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize,
253                                    TXB_RD_INFO_NODE *dst_rd_info) {
254   TXB_RD_RECORD *rd_records_table[4] = { x->txb_rd_record_8X8,
255                                          x->txb_rd_record_16X16,
256                                          x->txb_rd_record_32X32,
257                                          x->txb_rd_record_64X64 };
258   const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
259   const int bw = block_size_wide[bsize];
260   const int bh = block_size_high[bsize];
261 
262   // Hashing is performed only for square TX sizes larger than TX_4X4
263   if (max_square_tx_size < TX_8X8) return 0;
264   const int diff_stride = bw;
265   const struct macroblock_plane *const p = &x->plane[0];
266   const int16_t *diff = &p->src_diff[0];
267   init_rd_record_tree(dst_rd_info, bsize);
268   // Coordinates of the top-left corner of current block within the superblock
269   // measured in pixels:
270   const int mi_row = x->e_mbd.mi_row;
271   const int mi_col = x->e_mbd.mi_col;
272   const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
273   const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
274   int cur_rd_info_idx = 0;
275   int cur_tx_depth = 0;
276   TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
277   while (cur_tx_depth <= MAX_VARTX_DEPTH) {
278     const int cur_tx_bw = tx_size_wide[cur_tx_size];
279     const int cur_tx_bh = tx_size_high[cur_tx_size];
280     if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
281     const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
282     const int tx_size_idx = cur_tx_size - TX_8X8;
283     for (int row = 0; row < bh; row += cur_tx_bh) {
284       for (int col = 0; col < bw; col += cur_tx_bw) {
285         if (cur_tx_bw != cur_tx_bh) {
286           // Use dummy nodes for all rectangular transforms within the
287           // TX size search tree.
288           dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
289         } else {
290           // Get spatial location of this TX block within the superblock
291           // (measured in cur_tx_bsize units).
292           const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
293           const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;
294 
295           int16_t hash_data[MAX_SB_SQUARE];
296           int16_t *cur_hash_row = hash_data;
297           const int16_t *cur_diff_row = diff + row * diff_stride + col;
298           for (int i = 0; i < cur_tx_bh; i++) {
299             memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
300             cur_hash_row += cur_tx_bw;
301             cur_diff_row += diff_stride;
302           }
303           const int hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
304                                                 (uint8_t *)hash_data,
305                                                 2 * cur_tx_bw * cur_tx_bh);
306           // Find corresponding RD info based on the hash value.
307           const int record_idx =
308               row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
309           TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
310           int idx = find_tx_size_rd_info(records, hash);
311           dst_rd_info[cur_rd_info_idx].rd_info_array =
312               &records->tx_rd_info[idx];
313         }
314         ++cur_rd_info_idx;
315       }
316     }
317     cur_tx_size = next_tx_size;
318     ++cur_tx_depth;
319   }
320   return 1;
321 }
322 
get_block_residue_hash(MACROBLOCK * x,BLOCK_SIZE bsize)323 static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
324   const int rows = block_size_high[bsize];
325   const int cols = block_size_wide[bsize];
326   const int16_t *diff = x->plane[0].src_diff;
327   const uint32_t hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
328                                              (uint8_t *)diff, 2 * rows * cols);
329   return (hash << 5) + bsize;
330 }
331 
find_mb_rd_info(const MB_RD_RECORD * const mb_rd_record,const int64_t ref_best_rd,const uint32_t hash)332 static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
333                                       const int64_t ref_best_rd,
334                                       const uint32_t hash) {
335   int32_t match_index = -1;
336   if (ref_best_rd != INT64_MAX) {
337     for (int i = 0; i < mb_rd_record->num; ++i) {
338       const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
339       // If there is a match in the tx_rd_record, fetch the RD decision and
340       // terminate early.
341       if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
342         match_index = index;
343         break;
344       }
345     }
346   }
347   return match_index;
348 }
349 
fetch_tx_rd_info(int n4,const MB_RD_INFO * const tx_rd_info,RD_STATS * const rd_stats,MACROBLOCK * const x)350 static AOM_INLINE void fetch_tx_rd_info(int n4,
351                                         const MB_RD_INFO *const tx_rd_info,
352                                         RD_STATS *const rd_stats,
353                                         MACROBLOCK *const x) {
354   MACROBLOCKD *const xd = &x->e_mbd;
355   MB_MODE_INFO *const mbmi = xd->mi[0];
356   mbmi->tx_size = tx_rd_info->tx_size;
357   memcpy(x->blk_skip, tx_rd_info->blk_skip,
358          sizeof(tx_rd_info->blk_skip[0]) * n4);
359   av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
360   av1_copy_array(xd->tx_type_map, tx_rd_info->tx_type_map, n4);
361   *rd_stats = tx_rd_info->rd_stats;
362 }
363 
364 // Compute the pixel domain distortion from diff on all visible 4x4s in the
365 // transform block.
pixel_diff_dist(const MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8)366 static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
367                                       int blk_row, int blk_col,
368                                       const BLOCK_SIZE plane_bsize,
369                                       const BLOCK_SIZE tx_bsize,
370                                       unsigned int *block_mse_q8) {
371   int visible_rows, visible_cols;
372   const MACROBLOCKD *xd = &x->e_mbd;
373   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
374                      NULL, &visible_cols, &visible_rows);
375   const int diff_stride = block_size_wide[plane_bsize];
376   const int16_t *diff = x->plane[plane].src_diff;
377 
378   diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
379   uint64_t sse =
380       aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
381   if (block_mse_q8 != NULL) {
382     if (visible_cols > 0 && visible_rows > 0)
383       *block_mse_q8 =
384           (unsigned int)((256 * sse) / (visible_cols * visible_rows));
385     else
386       *block_mse_q8 = UINT_MAX;
387   }
388   return sse;
389 }
390 
391 // Uses simple features on top of DCT coefficients to quickly predict
392 // whether optimal RD decision is to skip encoding the residual.
393 // The sse value is stored in dist.
predict_skip_flag(MACROBLOCK * x,BLOCK_SIZE bsize,int64_t * dist,int reduced_tx_set)394 static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
395                              int reduced_tx_set) {
396   const int bw = block_size_wide[bsize];
397   const int bh = block_size_high[bsize];
398   const MACROBLOCKD *xd = &x->e_mbd;
399   const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
400 
401   *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
402 
403   const int64_t mse = *dist / bw / bh;
404   // Normalized quantizer takes the transform upscaling factor (8 for tx size
405   // smaller than 32) into account.
406   const int16_t normalized_dc_q = dc_q >> 3;
407   const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
408   // For faster early skip decision, use dist to compare against threshold so
409   // that quality risk is less for the skip=1 decision. Otherwise, use mse
410   // since the fwd_txfm coeff checks will take care of quality
411   // TODO(any): Use dist to return 0 when predict_skip_level is 1
412   int64_t pred_err = (x->predict_skip_level >= 2) ? *dist : mse;
413   // Predict not to skip when error is larger than threshold.
414   if (pred_err > mse_thresh) return 0;
415   // Return as skip otherwise for aggressive early skip
416   else if (x->predict_skip_level >= 2)
417     return 1;
418 
419   const int max_tx_size = max_predict_sf_tx_size[bsize];
420   const int tx_h = tx_size_high[max_tx_size];
421   const int tx_w = tx_size_wide[max_tx_size];
422   DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
423   TxfmParam param;
424   param.tx_type = DCT_DCT;
425   param.tx_size = max_tx_size;
426   param.bd = xd->bd;
427   param.is_hbd = is_cur_buf_hbd(xd);
428   param.lossless = 0;
429   param.tx_set_type = av1_get_ext_tx_set_type(
430       param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
431   const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
432   const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
433   const int16_t *src_diff = x->plane[0].src_diff;
434   const int n_coeff = tx_w * tx_h;
435   const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
436   const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
437   const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
438   for (int row = 0; row < bh; row += tx_h) {
439     for (int col = 0; col < bw; col += tx_w) {
440       av1_fwd_txfm(src_diff + col, coefs, bw, &param);
441       // Operating on TX domain, not pixels; we want the QTX quantizers
442       const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
443       if (dc_coef >= dc_thresh) return 0;
444       for (int i = 1; i < n_coeff; ++i) {
445         const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
446         if (ac_coef >= ac_thresh) return 0;
447       }
448     }
449     src_diff += tx_h * bw;
450   }
451   return 1;
452 }
453 
454 // Used to set proper context for early termination with skip = 1.
set_skip_flag(MACROBLOCK * x,RD_STATS * rd_stats,int bsize,int64_t dist)455 static AOM_INLINE void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats,
456                                      int bsize, int64_t dist) {
457   MACROBLOCKD *const xd = &x->e_mbd;
458   MB_MODE_INFO *const mbmi = xd->mi[0];
459   const int n4 = bsize_to_num_blk(bsize);
460   const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
461   memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
462   memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
463   mbmi->tx_size = tx_size;
464   for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
465   rd_stats->skip = 1;
466   if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
467   rd_stats->dist = rd_stats->sse = (dist << 4);
468   // Though decision is to make the block as skip based on luma stats,
469   // it is possible that block becomes non skip after chroma rd. In addition
470   // intermediate non skip costs calculated by caller function will be
471   // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
472   // accounted). Hence intermediate rate is populated to code the luma tx blks
473   // as skip, the caller function based on final rd decision (i.e., skip vs
474   // non-skip) sets the final rate accordingly. Here the rate populated
475   // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
476   // size possible) in the current block. Eg: For 128*128 block, rate would be
477   // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
478   // block as 'all zeros'
479   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
480   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
481   av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
482   ENTROPY_CONTEXT *ta = ctxa;
483   ENTROPY_CONTEXT *tl = ctxl;
484   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
485   TXB_CTX txb_ctx;
486   get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
487   const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
488                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
489   rd_stats->rate = zero_blk_rate *
490                    (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
491                    (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
492 }
493 
save_tx_rd_info(int n4,uint32_t hash,const MACROBLOCK * const x,const RD_STATS * const rd_stats,MB_RD_RECORD * tx_rd_record)494 static AOM_INLINE void save_tx_rd_info(int n4, uint32_t hash,
495                                        const MACROBLOCK *const x,
496                                        const RD_STATS *const rd_stats,
497                                        MB_RD_RECORD *tx_rd_record) {
498   int index;
499   if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
500     index =
501         (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
502     ++tx_rd_record->num;
503   } else {
504     index = tx_rd_record->index_start;
505     tx_rd_record->index_start =
506         (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
507   }
508   MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
509   const MACROBLOCKD *const xd = &x->e_mbd;
510   const MB_MODE_INFO *const mbmi = xd->mi[0];
511   tx_rd_info->hash_value = hash;
512   tx_rd_info->tx_size = mbmi->tx_size;
513   memcpy(tx_rd_info->blk_skip, x->blk_skip,
514          sizeof(tx_rd_info->blk_skip[0]) * n4);
515   av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
516   av1_copy_array(tx_rd_info->tx_type_map, xd->tx_type_map, n4);
517   tx_rd_info->rd_stats = *rd_stats;
518 }
519 
get_search_init_depth(int mi_width,int mi_height,int is_inter,const SPEED_FEATURES * sf,int tx_size_search_method)520 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
521                                  const SPEED_FEATURES *sf,
522                                  int tx_size_search_method) {
523   if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
524 
525   if (sf->tx_sf.tx_size_search_lgr_block) {
526     if (mi_width > mi_size_wide[BLOCK_64X64] ||
527         mi_height > mi_size_high[BLOCK_64X64])
528       return MAX_VARTX_DEPTH;
529   }
530 
531   if (is_inter) {
532     return (mi_height != mi_width)
533                ? sf->tx_sf.inter_tx_size_search_init_depth_rect
534                : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
535   } else {
536     return (mi_height != mi_width)
537                ? sf->tx_sf.intra_tx_size_search_init_depth_rect
538                : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
539   }
540 }
541 
// Forward declaration of select_tx_block(). Its definition is not in the
// visible part of this file (presumably further down). The trailing
// |rd_info_node| presumably receives a node of the tree built by
// find_tx_size_rd_records() — confirm at the definition.
static AOM_INLINE void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
    TXB_RD_INFO_NODE *rd_info_node);
549 
550 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
551 // 0: Do not collect any RD stats
552 // 1: Collect RD stats for transform units
553 // 2: Collect RD stats for partition units
554 #if CONFIG_COLLECT_RD_STATS
555 
get_energy_distribution_fine(const AV1_COMP * cpi,BLOCK_SIZE bsize,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int need_4th,double * hordist,double * verdist)556 static AOM_INLINE void get_energy_distribution_fine(
557     const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
558     const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
559     double *verdist) {
560   const int bw = block_size_wide[bsize];
561   const int bh = block_size_high[bsize];
562   unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
563 
564   if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
565     // Special cases: calculate 'esq' values manually, as we don't have 'vf'
566     // functions for the 16 (very small) sub-blocks of this block.
567     const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
568     const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
569     assert(bw <= 32);
570     assert(bh <= 32);
571     assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
572     if (cpi->common.seq_params.use_highbitdepth) {
573       const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
574       const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
575       for (int i = 0; i < bh; ++i)
576         for (int j = 0; j < bw; ++j) {
577           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
578           esq[index] +=
579               (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
580               (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
581         }
582     } else {
583       for (int i = 0; i < bh; ++i)
584         for (int j = 0; j < bw; ++j) {
585           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
586           esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
587                         (src[j + i * src_stride] - dst[j + i * dst_stride]);
588         }
589     }
590   } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
591     const int f_index =
592         (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
593     assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
594     const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
595     assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
596     assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
597     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
598     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
599                             &esq[1]);
600     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
601                             &esq[2]);
602     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
603                             dst_stride, &esq[3]);
604     src += bh / 4 * src_stride;
605     dst += bh / 4 * dst_stride;
606 
607     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
608     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
609                             &esq[5]);
610     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
611                             &esq[6]);
612     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
613                             dst_stride, &esq[7]);
614     src += bh / 4 * src_stride;
615     dst += bh / 4 * dst_stride;
616 
617     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
618     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
619                             &esq[9]);
620     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
621                             &esq[10]);
622     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
623                             dst_stride, &esq[11]);
624     src += bh / 4 * src_stride;
625     dst += bh / 4 * dst_stride;
626 
627     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
628     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
629                             &esq[13]);
630     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
631                             &esq[14]);
632     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
633                             dst_stride, &esq[15]);
634   }
635 
636   double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
637                  esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
638                  esq[12] + esq[13] + esq[14] + esq[15];
639   if (total > 0) {
640     const double e_recip = 1.0 / total;
641     hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
642     hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
643     hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
644     if (need_4th) {
645       hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
646     }
647     verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
648     verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
649     verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
650     if (need_4th) {
651       verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
652     }
653   } else {
654     hordist[0] = verdist[0] = 0.25;
655     hordist[1] = verdist[1] = 0.25;
656     hordist[2] = verdist[2] = 0.25;
657     if (need_4th) {
658       hordist[3] = verdist[3] = 0.25;
659     }
660   }
661 }
662 
get_sse_norm(const int16_t * diff,int stride,int w,int h)663 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
664   double sum = 0.0;
665   for (int j = 0; j < h; ++j) {
666     for (int i = 0; i < w; ++i) {
667       const int err = diff[j * stride + i];
668       sum += err * err;
669     }
670   }
671   assert(w > 0 && h > 0);
672   return sum / (w * h);
673 }
674 
get_sad_norm(const int16_t * diff,int stride,int w,int h)675 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
676   double sum = 0.0;
677   for (int j = 0; j < h; ++j) {
678     for (int i = 0; i < w; ++i) {
679       sum += abs(diff[j * stride + i]);
680     }
681   }
682   assert(w > 0 && h > 0);
683   return sum / (w * h);
684 }
685 
get_2x2_normalized_sses_and_sads(const AV1_COMP * const cpi,BLOCK_SIZE tx_bsize,const uint8_t * const src,int src_stride,const uint8_t * const dst,int dst_stride,const int16_t * const src_diff,int diff_stride,double * const sse_norm_arr,double * const sad_norm_arr)686 static AOM_INLINE void get_2x2_normalized_sses_and_sads(
687     const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
688     int src_stride, const uint8_t *const dst, int dst_stride,
689     const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
690     double *const sad_norm_arr) {
691   const BLOCK_SIZE tx_bsize_half =
692       get_partition_subsize(tx_bsize, PARTITION_SPLIT);
693   if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
694     const int half_width = block_size_wide[tx_bsize] / 2;
695     const int half_height = block_size_high[tx_bsize] / 2;
696     for (int row = 0; row < 2; ++row) {
697       for (int col = 0; col < 2; ++col) {
698         const int16_t *const this_src_diff =
699             src_diff + row * half_height * diff_stride + col * half_width;
700         if (sse_norm_arr) {
701           sse_norm_arr[row * 2 + col] =
702               get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
703         }
704         if (sad_norm_arr) {
705           sad_norm_arr[row * 2 + col] =
706               get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
707         }
708       }
709     }
710   } else {  // use function pointers to calculate stats
711     const int half_width = block_size_wide[tx_bsize_half];
712     const int half_height = block_size_high[tx_bsize_half];
713     const int num_samples_half = half_width * half_height;
714     for (int row = 0; row < 2; ++row) {
715       for (int col = 0; col < 2; ++col) {
716         const uint8_t *const this_src =
717             src + row * half_height * src_stride + col * half_width;
718         const uint8_t *const this_dst =
719             dst + row * half_height * dst_stride + col * half_width;
720 
721         if (sse_norm_arr) {
722           unsigned int this_sse;
723           cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
724                                         dst_stride, &this_sse);
725           sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
726         }
727 
728         if (sad_norm_arr) {
729           const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
730               this_src, src_stride, this_dst, dst_stride);
731           sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
732         }
733       }
734     }
735   }
736 }
737 
738 #if CONFIG_COLLECT_RD_STATS == 1
get_mean(const int16_t * diff,int stride,int w,int h)739 static double get_mean(const int16_t *diff, int stride, int w, int h) {
740   double sum = 0.0;
741   for (int j = 0; j < h; ++j) {
742     for (int i = 0; i < w; ++i) {
743       sum += diff[j * stride + i];
744     }
745   }
746   assert(w > 0 && h > 0);
747   return sum / (w * h);
748 }
PrintTransformUnitStats(const AV1_COMP * const cpi,MACROBLOCK * x,const RD_STATS * const rd_stats,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,TX_TYPE tx_type,int64_t rd)749 static AOM_INLINE void PrintTransformUnitStats(
750     const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
751     int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
752     TX_TYPE tx_type, int64_t rd) {
753   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
754 
755   // Generate small sample to restrict output size.
756   static unsigned int seed = 21743;
757   if (lcg_rand16(&seed) % 256 > 0) return;
758 
759   const char output_file[] = "tu_stats.txt";
760   FILE *fout = fopen(output_file, "a");
761   if (!fout) return;
762 
763   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
764   const MACROBLOCKD *const xd = &x->e_mbd;
765   const int plane = 0;
766   struct macroblock_plane *const p = &x->plane[plane];
767   const struct macroblockd_plane *const pd = &xd->plane[plane];
768   const int txw = tx_size_wide[tx_size];
769   const int txh = tx_size_high[tx_size];
770   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
771   const int q_step = p->dequant_QTX[1] >> dequant_shift;
772   const int num_samples = txw * txh;
773 
774   const double rate_norm = (double)rd_stats->rate / num_samples;
775   const double dist_norm = (double)rd_stats->dist / num_samples;
776 
777   fprintf(fout, "%g %g", rate_norm, dist_norm);
778 
779   const int src_stride = p->src.stride;
780   const uint8_t *const src =
781       &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
782   const int dst_stride = pd->dst.stride;
783   const uint8_t *const dst =
784       &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
785   unsigned int sse;
786   cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
787   const double sse_norm = (double)sse / num_samples;
788 
789   const unsigned int sad =
790       cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
791   const double sad_norm = (double)sad / num_samples;
792 
793   fprintf(fout, " %g %g", sse_norm, sad_norm);
794 
795   const int diff_stride = block_size_wide[plane_bsize];
796   const int16_t *const src_diff =
797       &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];
798 
799   double sse_norm_arr[4], sad_norm_arr[4];
800   get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
801                                    dst_stride, src_diff, diff_stride,
802                                    sse_norm_arr, sad_norm_arr);
803   for (int i = 0; i < 4; ++i) {
804     fprintf(fout, " %g", sse_norm_arr[i]);
805   }
806   for (int i = 0; i < 4; ++i) {
807     fprintf(fout, " %g", sad_norm_arr[i]);
808   }
809 
810   const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
811   const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
812 
813   fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
814           tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
815 
816   int model_rate;
817   int64_t model_dist;
818   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
819                                    &model_rate, &model_dist);
820   const double model_rate_norm = (double)model_rate / num_samples;
821   const double model_dist_norm = (double)model_dist / num_samples;
822   fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
823 
824   const double mean = get_mean(src_diff, diff_stride, txw, txh);
825   float hor_corr, vert_corr;
826   av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
827                                   &vert_corr);
828   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
829 
830   double hdist[4] = { 0 }, vdist[4] = { 0 };
831   get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
832                                1, hdist, vdist);
833   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
834           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
835 
836   fprintf(fout, " %d %" PRId64, x->rdmult, rd);
837 
838   fprintf(fout, "\n");
839   fclose(fout);
840 }
841 #endif  // CONFIG_COLLECT_RD_STATS == 1
842 
843 #if CONFIG_COLLECT_RD_STATS >= 2
get_sse(const AV1_COMP * cpi,const MACROBLOCK * x)844 static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
845   const AV1_COMMON *cm = &cpi->common;
846   const int num_planes = av1_num_planes(cm);
847   const MACROBLOCKD *xd = &x->e_mbd;
848   const MB_MODE_INFO *mbmi = xd->mi[0];
849   int64_t total_sse = 0;
850   for (int plane = 0; plane < num_planes; ++plane) {
851     const struct macroblock_plane *const p = &x->plane[plane];
852     const struct macroblockd_plane *const pd = &xd->plane[plane];
853     const BLOCK_SIZE bs = get_plane_block_size(mbmi->sb_type, pd->subsampling_x,
854                                                pd->subsampling_y);
855     unsigned int sse;
856 
857     if (x->skip_chroma_rd && plane) continue;
858 
859     cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
860                        &sse);
861     total_sse += sse;
862   }
863   total_sse <<= 4;
864   return total_sse;
865 }
866 
get_est_rate_dist(const TileDataEnc * tile_data,BLOCK_SIZE bsize,int64_t sse,int * est_residue_cost,int64_t * est_dist)867 static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
868                              int64_t sse, int *est_residue_cost,
869                              int64_t *est_dist) {
870   aom_clear_system_state();
871   const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
872   if (md->ready) {
873     if (sse < md->dist_mean) {
874       *est_residue_cost = 0;
875       *est_dist = sse;
876     } else {
877       *est_dist = (int64_t)round(md->dist_mean);
878       const double est_ld = md->a * sse + md->b;
879       // Clamp estimated rate cost by INT_MAX / 2.
880       // TODO(angiebird@google.com): find better solution than clamping.
881       if (fabs(est_ld) < 1e-2) {
882         *est_residue_cost = INT_MAX / 2;
883       } else {
884         double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
885         if (est_residue_cost_dbl < 0) {
886           *est_residue_cost = 0;
887         } else {
888           *est_residue_cost =
889               (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
890         }
891       }
892       if (*est_residue_cost <= 0) {
893         *est_residue_cost = 0;
894         *est_dist = sse;
895       }
896     }
897     return 1;
898   }
899   return 0;
900 }
901 
get_highbd_diff_mean(const uint8_t * src8,int src_stride,const uint8_t * dst8,int dst_stride,int w,int h)902 static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
903                                    const uint8_t *dst8, int dst_stride, int w,
904                                    int h) {
905   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
906   const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
907   double sum = 0.0;
908   for (int j = 0; j < h; ++j) {
909     for (int i = 0; i < w; ++i) {
910       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
911       sum += diff;
912     }
913   }
914   assert(w > 0 && h > 0);
915   return sum / (w * h);
916 }
917 
get_diff_mean(const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int w,int h)918 static double get_diff_mean(const uint8_t *src, int src_stride,
919                             const uint8_t *dst, int dst_stride, int w, int h) {
920   double sum = 0.0;
921   for (int j = 0; j < h; ++j) {
922     for (int i = 0; i < w; ++i) {
923       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
924       sum += diff;
925     }
926   }
927   assert(w > 0 && h > 0);
928   return sum / (w * h);
929 }
930 
PrintPredictionUnitStats(const AV1_COMP * const cpi,const TileDataEnc * tile_data,MACROBLOCK * x,const RD_STATS * const rd_stats,BLOCK_SIZE plane_bsize)931 static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
932                                                 const TileDataEnc *tile_data,
933                                                 MACROBLOCK *x,
934                                                 const RD_STATS *const rd_stats,
935                                                 BLOCK_SIZE plane_bsize) {
936   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
937 
938   if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
939       (tile_data == NULL ||
940        !tile_data->inter_mode_rd_models[plane_bsize].ready))
941     return;
942   (void)tile_data;
943   // Generate small sample to restrict output size.
944   static unsigned int seed = 95014;
945 
946   if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
947       1)
948     return;
949 
950   const char output_file[] = "pu_stats.txt";
951   FILE *fout = fopen(output_file, "a");
952   if (!fout) return;
953 
954   MACROBLOCKD *const xd = &x->e_mbd;
955   const int plane = 0;
956   struct macroblock_plane *const p = &x->plane[plane];
957   struct macroblockd_plane *pd = &xd->plane[plane];
958   const int diff_stride = block_size_wide[plane_bsize];
959   int bw, bh;
960   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
961                      &bh);
962   const int num_samples = bw * bh;
963   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
964   const int q_step = p->dequant_QTX[1] >> dequant_shift;
965   const int shift = (xd->bd - 8);
966 
967   const double rate_norm = (double)rd_stats->rate / num_samples;
968   const double dist_norm = (double)rd_stats->dist / num_samples;
969   const double rdcost_norm =
970       (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
971 
972   fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
973 
974   const int src_stride = p->src.stride;
975   const uint8_t *const src = p->src.buf;
976   const int dst_stride = pd->dst.stride;
977   const uint8_t *const dst = pd->dst.buf;
978   const int16_t *const src_diff = p->src_diff;
979 
980   int64_t sse = calculate_sse(xd, p, pd, bw, bh);
981   const double sse_norm = (double)sse / num_samples;
982 
983   const unsigned int sad =
984       cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
985   const double sad_norm =
986       (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
987 
988   fprintf(fout, " %g %g", sse_norm, sad_norm);
989 
990   double sse_norm_arr[4], sad_norm_arr[4];
991   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
992                                    dst_stride, src_diff, diff_stride,
993                                    sse_norm_arr, sad_norm_arr);
994   if (shift) {
995     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
996     for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
997   }
998   for (int i = 0; i < 4; ++i) {
999     fprintf(fout, " %g", sse_norm_arr[i]);
1000   }
1001   for (int i = 0; i < 4; ++i) {
1002     fprintf(fout, " %g", sad_norm_arr[i]);
1003   }
1004 
1005   fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
1006 
1007   int model_rate;
1008   int64_t model_dist;
1009   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
1010                                    &model_rate, &model_dist);
1011   const double model_rdcost_norm =
1012       (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
1013   const double model_rate_norm = (double)model_rate / num_samples;
1014   const double model_dist_norm = (double)model_dist / num_samples;
1015   fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
1016           model_rdcost_norm);
1017 
1018   double mean;
1019   if (is_cur_buf_hbd(xd)) {
1020     mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
1021                                 pd->dst.stride, bw, bh);
1022   } else {
1023     mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
1024                          bw, bh);
1025   }
1026   mean /= (1 << shift);
1027   float hor_corr, vert_corr;
1028   av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
1029                                   &vert_corr);
1030   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
1031 
1032   double hdist[4] = { 0 }, vdist[4] = { 0 };
1033   get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
1034                                dst_stride, 1, hdist, vdist);
1035   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
1036           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
1037 
1038   if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
1039     assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
1040     const int64_t overall_sse = get_sse(cpi, x);
1041     int est_residue_cost = 0;
1042     int64_t est_dist = 0;
1043     get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
1044                       &est_dist);
1045     const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
1046     const double est_dist_norm = (double)est_dist / num_samples;
1047     const double est_rdcost_norm =
1048         (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
1049     fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
1050             est_rdcost_norm);
1051   }
1052 
1053   fprintf(fout, "\n");
1054   fclose(fout);
1055 }
1056 #endif  // CONFIG_COLLECT_RD_STATS >= 2
1057 #endif  // CONFIG_COLLECT_RD_STATS
1058 
inverse_transform_block_facade(MACROBLOCKD * xd,int plane,int block,int blk_row,int blk_col,int eob,int reduced_tx_set)1059 static AOM_INLINE void inverse_transform_block_facade(MACROBLOCKD *xd,
1060                                                       int plane, int block,
1061                                                       int blk_row, int blk_col,
1062                                                       int eob,
1063                                                       int reduced_tx_set) {
1064   if (!eob) return;
1065 
1066   struct macroblockd_plane *const pd = &xd->plane[plane];
1067   tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
1068   const PLANE_TYPE plane_type = get_plane_type(plane);
1069   const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
1070   const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
1071                                           tx_size, reduced_tx_set);
1072   const int dst_stride = pd->dst.stride;
1073   uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
1074   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
1075                               dst_stride, eob, reduced_tx_set);
1076 }
1077 
recon_intra(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,int skip_trellis,TX_TYPE best_tx_type,int do_quant,int * rate_cost,uint16_t best_eob)1078 static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1079                                int block, int blk_row, int blk_col,
1080                                BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1081                                const TXB_CTX *const txb_ctx, int skip_trellis,
1082                                TX_TYPE best_tx_type, int do_quant,
1083                                int *rate_cost, uint16_t best_eob) {
1084   const AV1_COMMON *cm = &cpi->common;
1085   MACROBLOCKD *xd = &x->e_mbd;
1086   MB_MODE_INFO *mbmi = xd->mi[0];
1087   const int is_inter = is_inter_block(mbmi);
1088   if (!is_inter && best_eob &&
1089       (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
1090        blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
1091     // if the quantized coefficients are stored in the dqcoeff buffer, we don't
1092     // need to do transform and quantization again.
1093     if (do_quant) {
1094       TxfmParam txfm_param_intra;
1095       QUANT_PARAM quant_param_intra;
1096       av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
1097       av1_setup_quant(tx_size, !skip_trellis,
1098                       skip_trellis
1099                           ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
1100                                                     : AV1_XFORM_QUANT_FP)
1101                           : AV1_XFORM_QUANT_FP,
1102                       cpi->oxcf.quant_b_adapt, &quant_param_intra);
1103       av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
1104                         &quant_param_intra);
1105       av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
1106                       &txfm_param_intra, &quant_param_intra);
1107       if (quant_param_intra.use_optimize_b) {
1108         av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
1109                        cpi->sf.rd_sf.trellis_eob_fast, rate_cost);
1110       }
1111     }
1112 
1113     inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
1114                                    x->plane[plane].eobs[block],
1115                                    cm->features.reduced_tx_set_used);
1116 
1117     // This may happen because of hash collision. The eob stored in the hash
1118     // table is non-zero, but the real eob is zero. We need to make sure tx_type
1119     // is DCT_DCT in this case.
1120     if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
1121         best_tx_type != DCT_DCT) {
1122       update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
1123     }
1124   }
1125 }
1126 
pixel_dist_visible_only(const AV1_COMP * const cpi,const MACROBLOCK * x,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,const BLOCK_SIZE tx_bsize,int txb_rows,int txb_cols,int visible_rows,int visible_cols)1127 static unsigned pixel_dist_visible_only(
1128     const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
1129     const int src_stride, const uint8_t *dst, const int dst_stride,
1130     const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
1131     int visible_cols) {
1132   unsigned sse;
1133 
1134   if (txb_rows == visible_rows && txb_cols == visible_cols) {
1135     cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
1136     return sse;
1137   }
1138 
1139 #if CONFIG_AV1_HIGHBITDEPTH
1140   const MACROBLOCKD *xd = &x->e_mbd;
1141   if (is_cur_buf_hbd(xd)) {
1142     uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
1143                                              visible_cols, visible_rows);
1144     return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
1145   }
1146 #else
1147   (void)x;
1148 #endif
1149   sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
1150                          visible_rows);
1151   return sse;
1152 }
1153 
1154 // Compute the pixel domain distortion from src and dst on all visible 4x4s in
1155 // the
1156 // transform block.
pixel_dist(const AV1_COMP * const cpi,const MACROBLOCK * x,int plane,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize)1157 static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
1158                            int plane, const uint8_t *src, const int src_stride,
1159                            const uint8_t *dst, const int dst_stride,
1160                            int blk_row, int blk_col,
1161                            const BLOCK_SIZE plane_bsize,
1162                            const BLOCK_SIZE tx_bsize) {
1163   int txb_rows, txb_cols, visible_rows, visible_cols;
1164   const MACROBLOCKD *xd = &x->e_mbd;
1165 
1166   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
1167                      &txb_cols, &txb_rows, &visible_cols, &visible_rows);
1168   assert(visible_rows > 0);
1169   assert(visible_cols > 0);
1170 
1171   unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
1172                                          dst_stride, tx_bsize, txb_rows,
1173                                          txb_cols, visible_rows, visible_cols);
1174 
1175   return sse;
1176 }
1177 
dist_block_px_domain(const AV1_COMP * cpi,MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,int block,int blk_row,int blk_col,TX_SIZE tx_size)1178 static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
1179                                            int plane, BLOCK_SIZE plane_bsize,
1180                                            int block, int blk_row, int blk_col,
1181                                            TX_SIZE tx_size) {
1182   MACROBLOCKD *const xd = &x->e_mbd;
1183   const struct macroblock_plane *const p = &x->plane[plane];
1184   const struct macroblockd_plane *const pd = &xd->plane[plane];
1185   const uint16_t eob = p->eobs[block];
1186   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
1187   const int bsw = block_size_wide[tx_bsize];
1188   const int bsh = block_size_high[tx_bsize];
1189   const int src_stride = x->plane[plane].src.stride;
1190   const int dst_stride = xd->plane[plane].dst.stride;
1191   // Scale the transform block index to pixel unit.
1192   const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
1193   const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
1194   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
1195   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
1196   const tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
1197 
1198   assert(cpi != NULL);
1199   assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
1200 
1201   uint8_t *recon;
1202   DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
1203 
1204 #if CONFIG_AV1_HIGHBITDEPTH
1205   if (is_cur_buf_hbd(xd)) {
1206     recon = CONVERT_TO_BYTEPTR(recon16);
1207     av1_highbd_convolve_2d_copy_sr(CONVERT_TO_SHORTPTR(dst), dst_stride,
1208                                    CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw,
1209                                    bsh, NULL, NULL, 0, 0, NULL, xd->bd);
1210   } else {
1211     recon = (uint8_t *)recon16;
1212     av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
1213                             NULL, 0, 0, NULL);
1214   }
1215 #else
1216   recon = (uint8_t *)recon16;
1217   av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
1218                           NULL, 0, 0, NULL);
1219 #endif
1220 
1221   const PLANE_TYPE plane_type = get_plane_type(plane);
1222   TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
1223                                     cpi->common.features.reduced_tx_set_used);
1224   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
1225                               MAX_TX_SIZE, eob,
1226                               cpi->common.features.reduced_tx_set_used);
1227 
1228   return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
1229                          blk_row, blk_col, plane_bsize, tx_bsize);
1230 }
1231 
get_intra_txb_hash(MACROBLOCK * x,int plane,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size)1232 static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
1233                                    int blk_col, BLOCK_SIZE plane_bsize,
1234                                    TX_SIZE tx_size) {
1235   int16_t tmp_data[64 * 64];
1236   const int diff_stride = block_size_wide[plane_bsize];
1237   const int16_t *diff = x->plane[plane].src_diff;
1238   const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
1239   const int txb_w = tx_size_wide[tx_size];
1240   const int txb_h = tx_size_high[tx_size];
1241   uint8_t *hash_data = (uint8_t *)cur_diff_row;
1242   if (txb_w != diff_stride) {
1243     int16_t *cur_hash_row = tmp_data;
1244     for (int i = 0; i < txb_h; i++) {
1245       memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
1246       cur_hash_row += txb_w;
1247       cur_diff_row += diff_stride;
1248     }
1249     hash_data = (uint8_t *)tmp_data;
1250   }
1251   CRC32C *crc = &x->mb_rd_record.crc_calculator;
1252   const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
1253   return (hash << 5) + tx_size;
1254 }
1255 
// Pruning thresholds shared by prune_txk_type and prune_txk_type_separ.
// prune_factors[] is fixed-point with scale 1000 (200 == 0.2).
// mul_factors[] is fixed-point with scale 100 (80 == 0.8).
// NOTE(review): the index presumably selects a pruning aggressiveness level;
// confirm against the callers.
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100
1259 
is_intra_hash_match(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,TXB_RD_INFO ** intra_txb_rd_info,const int tx_type_map_idx,uint16_t * cur_joint_ctx)1260 static INLINE int is_intra_hash_match(const AV1_COMP *cpi, MACROBLOCK *x,
1261                                       int plane, int blk_row, int blk_col,
1262                                       BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1263                                       const TXB_CTX *const txb_ctx,
1264                                       TXB_RD_INFO **intra_txb_rd_info,
1265                                       const int tx_type_map_idx,
1266                                       uint16_t *cur_joint_ctx) {
1267   MACROBLOCKD *xd = &x->e_mbd;
1268   assert(cpi->sf.tx_sf.use_intra_txb_hash &&
1269          frame_is_intra_only(&cpi->common) && !is_inter_block(xd->mi[0]) &&
1270          plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size]);
1271   const uint32_t intra_hash =
1272       get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
1273   const int intra_hash_idx =
1274       find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
1275   *intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
1276   *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
1277   if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
1278       x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
1279     xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
1280     const TX_TYPE ref_tx_type =
1281         av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1282                         cpi->common.features.reduced_tx_set_used);
1283     return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
1284   }
1285   return 0;
1286 }
1287 
1288 // R-D costs are sorted in ascending order.
sort_rd(int64_t rds[],int txk[],int len)1289 static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
1290   int i, j, k;
1291 
1292   for (i = 1; i <= len - 1; ++i) {
1293     for (j = 0; j < i; ++j) {
1294       if (rds[j] > rds[i]) {
1295         int64_t temprd;
1296         int tempi;
1297 
1298         temprd = rds[i];
1299         tempi = txk[i];
1300 
1301         for (k = i; k > j; k--) {
1302           rds[k] = rds[k - 1];
1303           txk[k] = txk[k - 1];
1304         }
1305 
1306         rds[j] = temprd;
1307         txk[j] = tempi;
1308         break;
1309       }
1310     }
1311   }
1312 }
1313 
dist_block_tx_domain(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int64_t * out_dist,int64_t * out_sse)1314 static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
1315                                         TX_SIZE tx_size, int64_t *out_dist,
1316                                         int64_t *out_sse) {
1317   MACROBLOCKD *const xd = &x->e_mbd;
1318   const struct macroblock_plane *const p = &x->plane[plane];
1319   const struct macroblockd_plane *const pd = &xd->plane[plane];
1320   // Transform domain distortion computation is more efficient as it does
1321   // not involve an inverse transform, but it is less accurate.
1322   const int buffer_length = av1_get_max_eob(tx_size);
1323   int64_t this_sse;
1324   // TX-domain results need to shift down to Q2/D10 to match pixel
1325   // domain distortion values which are in Q2^2
1326   int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
1327   const int block_offset = BLOCK_OFFSET(block);
1328   tran_low_t *const coeff = p->coeff + block_offset;
1329   tran_low_t *const dqcoeff = pd->dqcoeff + block_offset;
1330 #if CONFIG_AV1_HIGHBITDEPTH
1331   if (is_cur_buf_hbd(xd))
1332     *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
1333                                        xd->bd);
1334   else
1335     *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
1336 #else
1337   *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
1338 #endif
1339   *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
1340   *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
1341 }
1342 
prune_txk_type_separ(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,int * txk_map,int16_t allowed_tx_mask,int prune_factor,const TXB_CTX * const txb_ctx,int reduced_tx_set_used,int64_t ref_best_rd,int num_sel)1343 uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1344                               int block, TX_SIZE tx_size, int blk_row,
1345                               int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
1346                               int16_t allowed_tx_mask, int prune_factor,
1347                               const TXB_CTX *const txb_ctx,
1348                               int reduced_tx_set_used, int64_t ref_best_rd,
1349                               int num_sel) {
1350   const AV1_COMMON *cm = &cpi->common;
1351 
1352   int idx;
1353 
1354   int64_t rds_v[4];
1355   int64_t rds_h[4];
1356   int idx_v[4] = { 0, 1, 2, 3 };
1357   int idx_h[4] = { 0, 1, 2, 3 };
1358   int skip_v[4] = { 0 };
1359   int skip_h[4] = { 0 };
1360   const int idx_map[16] = {
1361     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1362     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1363     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1364     H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1365   };
1366 
1367   const int sel_pattern_v[16] = {
1368     0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
1369   };
1370   const int sel_pattern_h[16] = {
1371     0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
1372   };
1373 
1374   QUANT_PARAM quant_param;
1375   TxfmParam txfm_param;
1376   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1377   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.quant_b_adapt,
1378                   &quant_param);
1379   int tx_type;
1380   // to ensure we can try ones even outside of ext_tx_set of current block
1381   // this function should only be called for size < 16
1382   assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
1383   txfm_param.tx_set_type = EXT_TX_SET_ALL16;
1384 
1385   int rate_cost = 0;
1386   int64_t dist = 0, sse = 0;
1387   // evaluate horizontal with vertical DCT
1388   for (idx = 0; idx < 4; ++idx) {
1389     tx_type = idx_map[idx];
1390     txfm_param.tx_type = tx_type;
1391 
1392     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1393                     &quant_param);
1394 
1395     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1396 
1397     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1398                                               txb_ctx, reduced_tx_set_used, 0);
1399 
1400     rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
1401 
1402     if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
1403       skip_h[idx] = 1;
1404     }
1405   }
1406   sort_rd(rds_h, idx_h, 4);
1407   for (idx = 1; idx < 4; idx++) {
1408     if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
1409   }
1410 
1411   if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
1412 
1413   // evaluate vertical with the best horizontal chosen
1414   rds_v[0] = rds_h[0];
1415   int start_v = 1, end_v = 4;
1416   const int *idx_map_v = idx_map + idx_h[0];
1417 
1418   for (idx = start_v; idx < end_v; ++idx) {
1419     tx_type = idx_map_v[idx_v[idx] * 4];
1420     txfm_param.tx_type = tx_type;
1421 
1422     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1423                     &quant_param);
1424 
1425     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1426 
1427     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1428                                               txb_ctx, reduced_tx_set_used, 0);
1429 
1430     rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
1431 
1432     if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
1433       skip_v[idx] = 1;
1434     }
1435   }
1436   sort_rd(rds_v, idx_v, 4);
1437   for (idx = 1; idx < 4; idx++) {
1438     if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
1439   }
1440 
1441   // combine rd_h and rd_v to prune tx candidates
1442   int i_v, i_h;
1443   int64_t rds[16];
1444   int num_cand = 0, last = TX_TYPES - 1;
1445 
1446   for (int i = 0; i < 16; i++) {
1447     i_v = sel_pattern_v[i];
1448     i_h = sel_pattern_h[i];
1449     tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
1450     if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
1451         skip_v[idx_v[i_v]]) {
1452       txk_map[last] = tx_type;
1453       last--;
1454     } else {
1455       txk_map[num_cand] = tx_type;
1456       rds[num_cand] = rds_v[i_v] + rds_h[i_h];
1457       if (rds[num_cand] == 0) rds[num_cand] = 1;
1458       num_cand++;
1459     }
1460   }
1461   sort_rd(rds, txk_map, num_cand);
1462 
1463   uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1464   num_sel = AOMMIN(num_sel, num_cand);
1465 
1466   for (int i = 1; i < num_sel; i++) {
1467     int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
1468     if (factor < (int64_t)prune_factor)
1469       prune &= ~(1 << txk_map[i]);
1470     else
1471       break;
1472   }
1473   return prune;
1474 }
1475 
prune_txk_type(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,int * txk_map,uint16_t allowed_tx_mask,int prune_factor,const TXB_CTX * const txb_ctx,int reduced_tx_set_used)1476 uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1477                         int block, TX_SIZE tx_size, int blk_row, int blk_col,
1478                         BLOCK_SIZE plane_bsize, int *txk_map,
1479                         uint16_t allowed_tx_mask, int prune_factor,
1480                         const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
1481   const AV1_COMMON *cm = &cpi->common;
1482   int tx_type;
1483 
1484   int64_t rds[TX_TYPES];
1485 
1486   int num_cand = 0;
1487   int last = TX_TYPES - 1;
1488 
1489   TxfmParam txfm_param;
1490   QUANT_PARAM quant_param;
1491   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1492   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.quant_b_adapt,
1493                   &quant_param);
1494 
1495   for (int idx = 0; idx < TX_TYPES; idx++) {
1496     tx_type = idx;
1497     int rate_cost = 0;
1498     int64_t dist = 0, sse = 0;
1499     if (!(allowed_tx_mask & (1 << tx_type))) {
1500       txk_map[last] = tx_type;
1501       last--;
1502       continue;
1503     }
1504     txfm_param.tx_type = tx_type;
1505 
1506     // do txfm and quantization
1507     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1508                     &quant_param);
1509     // estimate rate cost
1510     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1511                                               txb_ctx, reduced_tx_set_used, 0);
1512     // tx domain dist
1513     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1514 
1515     txk_map[num_cand] = tx_type;
1516     rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
1517     if (rds[num_cand] == 0) rds[num_cand] = 1;
1518     num_cand++;
1519   }
1520 
1521   if (num_cand == 0) return (uint16_t)0xFFFF;
1522 
1523   sort_rd(rds, txk_map, num_cand);
1524   uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1525 
1526   // 0 < prune_factor <= 1000 controls aggressiveness
1527   int64_t factor = 0;
1528   for (int idx = 1; idx < num_cand; idx++) {
1529     factor = 1000 * (rds[idx] - rds[0]) / rds[0];
1530     if (factor < (int64_t)prune_factor)
1531       prune &= ~(1 << txk_map[idx]);
1532     else
1533       break;
1534   }
1535   return prune;
1536 }
1537 
// Per-TX_SIZE score thresholds used by the 2D tx-type pruning model, indexed
// as prune_2D_adaptive_thresholds[tx_size][i]. They were calibrated so that
// selecting the threshold at index i prunes about i + 1 tx types on average.
// Sizes without a calibrated table are NULL. The table is read-only; both the
// pointers and the pointed-to values are const.
static const float *const prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (const float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f,
                   0.04065f, 0.04724f, 0.05383f, 0.06067f, 0.06799f,
                   0.07605f, 0.08533f, 0.09778f, 0.11780f },
  // TX_8X8
  (const float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f,
                   0.02502f, 0.03381f, 0.04333f, 0.05286f, 0.06287f,
                   0.07434f, 0.08850f, 0.10803f, 0.14124f },
  // TX_16X16
  (const float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f,
                   0.06335f, 0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (const float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f,
                   0.03723f, 0.04456f, 0.05188f, 0.05920f, 0.06702f,
                   0.07605f, 0.08704f, 0.10168f, 0.12585f },
  // TX_8X4
  (const float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f,
                   0.03528f, 0.04358f, 0.05164f, 0.05994f, 0.06848f,
                   0.07849f, 0.09021f, 0.10583f, 0.13123f },
  // TX_8X16
  (const float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f,
                   0.02722f, 0.03552f, 0.04382f, 0.05237f, 0.06189f,
                   0.07336f, 0.08728f, 0.10730f, 0.14221f },
  // TX_16X8
  (const float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f,
                   0.02966f, 0.03772f, 0.04578f, 0.05383f, 0.06262f,
                   0.07288f, 0.08582f, 0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (const float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f,
                   0.03430f, 0.04211f, 0.04968f, 0.05750f, 0.06580f,
                   0.07507f, 0.08655f, 0.10242f, 0.12878f },
  // TX_16X4
  (const float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f,
                   0.03601f, 0.04358f, 0.05115f, 0.05896f, 0.06702f,
                   0.07629f, 0.08752f, 0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
1598 
// Sorts prob[0..len-1] into descending order and applies the same permutation
// to txk[], so every tx type stays paired with its probability.
// Insertion sort: each element is inserted before the first strictly smaller
// entry, so equal probabilities keep their relative order (stable).
static INLINE void sort_probability(float prob[], int txk[], int len) {
  for (int cur = 1; cur < len; ++cur) {
    // Forward scan for the first slot holding a strictly smaller value.
    int pos = 0;
    while (pos < cur && !(prob[pos] < prob[cur])) ++pos;
    if (pos == cur) continue;  // Already in place.

    const float moved_prob = prob[cur];
    const int moved_txk = txk[cur];
    for (int k = cur; k > pos; --k) {
      prob[k] = prob[k - 1];
      txk[k] = txk[k - 1];
    }
    prob[pos] = moved_prob;
    txk[pos] = moved_txk;
  }
}
1624 
get_adaptive_thresholds(TX_SIZE tx_size,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_mode)1625 static INLINE float get_adaptive_thresholds(TX_SIZE tx_size,
1626                                             TxSetType tx_set_type,
1627                                             TX_TYPE_PRUNE_MODE prune_mode) {
1628   const int prune_aggr_table[4][2] = { { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 } };
1629   int pruning_aggressiveness = 0;
1630   if (tx_set_type == EXT_TX_SET_ALL16)
1631     pruning_aggressiveness =
1632         prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][0];
1633   else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
1634     pruning_aggressiveness =
1635         prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][1];
1636 
1637   return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
1638 }
1639 
get_energy_distribution_finer(const int16_t * diff,int stride,int bw,int bh,float * hordist,float * verdist)1640 static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
1641                                                      int stride, int bw, int bh,
1642                                                      float *hordist,
1643                                                      float *verdist) {
1644   // First compute downscaled block energy values (esq); downscale factors
1645   // are defined by w_shift and h_shift.
1646   unsigned int esq[256];
1647   const int w_shift = bw <= 8 ? 0 : 1;
1648   const int h_shift = bh <= 8 ? 0 : 1;
1649   const int esq_w = bw >> w_shift;
1650   const int esq_h = bh >> h_shift;
1651   const int esq_sz = esq_w * esq_h;
1652   int i, j;
1653   memset(esq, 0, esq_sz * sizeof(esq[0]));
1654   if (w_shift) {
1655     for (i = 0; i < bh; i++) {
1656       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1657       const int16_t *cur_diff_row = diff + i * stride;
1658       for (j = 0; j < bw; j += 2) {
1659         cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
1660                                 cur_diff_row[j + 1] * cur_diff_row[j + 1]);
1661       }
1662     }
1663   } else {
1664     for (i = 0; i < bh; i++) {
1665       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1666       const int16_t *cur_diff_row = diff + i * stride;
1667       for (j = 0; j < bw; j++) {
1668         cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
1669       }
1670     }
1671   }
1672 
1673   uint64_t total = 0;
1674   for (i = 0; i < esq_sz; i++) total += esq[i];
1675 
1676   // Output hordist and verdist arrays are normalized 1D projections of esq
1677   if (total == 0) {
1678     float hor_val = 1.0f / esq_w;
1679     for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
1680     float ver_val = 1.0f / esq_h;
1681     for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
1682     return;
1683   }
1684 
1685   const float e_recip = 1.0f / (float)total;
1686   memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
1687   memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
1688   const unsigned int *cur_esq_row;
1689   for (i = 0; i < esq_h - 1; i++) {
1690     cur_esq_row = esq + i * esq_w;
1691     for (j = 0; j < esq_w - 1; j++) {
1692       hordist[j] += (float)cur_esq_row[j];
1693       verdist[i] += (float)cur_esq_row[j];
1694     }
1695     verdist[i] += (float)cur_esq_row[j];
1696   }
1697   cur_esq_row = esq + i * esq_w;
1698   for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
1699 
1700   for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
1701   for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
1702 }
1703 
prune_tx_2D(MACROBLOCK * x,BLOCK_SIZE bsize,TX_SIZE tx_size,int blk_row,int blk_col,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_mode,int * txk_map,uint16_t * allowed_tx_mask)1704 static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1705                         int blk_row, int blk_col, TxSetType tx_set_type,
1706                         TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
1707                         uint16_t *allowed_tx_mask) {
1708   int tx_type_table_2D[16] = {
1709     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1710     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1711     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1712     H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1713   };
1714   if (tx_set_type != EXT_TX_SET_ALL16 &&
1715       tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1716     return;
1717 #if CONFIG_NN_V2
1718   NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1719   NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1720 #else
1721   const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1722   const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1723 #endif
1724   if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.
1725 
1726   aom_clear_system_state();
1727   float hfeatures[16], vfeatures[16];
1728   float hscores[4], vscores[4];
1729   float scores_2D_raw[16];
1730   float scores_2D[16];
1731   const int bw = tx_size_wide[tx_size];
1732   const int bh = tx_size_high[tx_size];
1733   const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1734   const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1735   assert(hfeatures_num <= 16);
1736   assert(vfeatures_num <= 16);
1737 
1738   const struct macroblock_plane *const p = &x->plane[0];
1739   const int diff_stride = block_size_wide[bsize];
1740   const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1741   get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1742                                 vfeatures);
1743   av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1744                                   &hfeatures[hfeatures_num - 1],
1745                                   &vfeatures[vfeatures_num - 1]);
1746   aom_clear_system_state();
1747 #if CONFIG_NN_V2
1748   av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
1749   av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
1750 #else
1751   av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
1752   av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
1753 #endif
1754   aom_clear_system_state();
1755 
1756   for (int i = 0; i < 4; i++) {
1757     float *cur_scores_2D = scores_2D_raw + i * 4;
1758     cur_scores_2D[0] = vscores[i] * hscores[0];
1759     cur_scores_2D[1] = vscores[i] * hscores[1];
1760     cur_scores_2D[2] = vscores[i] * hscores[2];
1761     cur_scores_2D[3] = vscores[i] * hscores[3];
1762   }
1763 
1764   av1_nn_softmax(scores_2D_raw, scores_2D, 16);
1765 
1766   const float score_thresh =
1767       get_adaptive_thresholds(tx_size, tx_set_type, prune_mode);
1768 
1769   // Always keep the TX type with the highest score, prune all others with
1770   // score below score_thresh.
1771   int max_score_i = 0;
1772   float max_score = 0.0f;
1773   uint16_t allow_bitmask = 0;
1774   float sum_score = 0.0;
1775   // Calculate sum of allowed tx type score and Populate allow bit mask based
1776   // on score_thresh and allowed_tx_mask
1777   for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1778     int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
1779     if (scores_2D[tx_idx] > max_score && allow_tx_type) {
1780       max_score = scores_2D[tx_idx];
1781       max_score_i = tx_idx;
1782     }
1783     if (scores_2D[tx_idx] >= score_thresh && allow_tx_type) {
1784       // Set allow mask based on score_thresh
1785       allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
1786 
1787       // Accumulate score of allowed tx type
1788       sum_score += scores_2D[tx_idx];
1789     }
1790   }
1791   if (!((allow_bitmask >> max_score_i) & 0x01)) {
1792     // Set allow mask based on tx type with max score
1793     allow_bitmask |= (1 << tx_type_table_2D[max_score_i]);
1794     sum_score += scores_2D[max_score_i];
1795   }
1796   // Sort tx type probability of all types
1797   sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
1798 
1799   // Enable more pruning based on tx type probability and number of allowed tx
1800   // types
1801   if (prune_mode == PRUNE_2D_AGGRESSIVE) {
1802     float temp_score = 0.0;
1803     float score_ratio = 0.0;
1804     int tx_idx, tx_count = 0;
1805     const float inv_sum_score = 100 / sum_score;
1806     // Get allowed tx types based on sorted probability score and tx count
1807     for (tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1808       // Skip the tx type which has more than 30% of cumulative
1809       // probability and allowed tx type count is more than 2
1810       if (score_ratio > 30.0 && tx_count >= 2) break;
1811 
1812       // Calculate cumulative probability of allowed tx types
1813       if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
1814         // Calculate cumulative probability
1815         temp_score += scores_2D[tx_idx];
1816 
1817         // Calculate percentage of cumulative probability of allowed tx type
1818         score_ratio = temp_score * inv_sum_score;
1819         tx_count++;
1820       }
1821     }
1822     // Set remaining tx types as pruned
1823     for (; tx_idx < TX_TYPES; tx_idx++)
1824       allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
1825   }
1826   memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
1827   *allowed_tx_mask = allow_bitmask;
1828 }
1829 
get_dev(float mean,double x2_sum,int num)1830 static float get_dev(float mean, double x2_sum, int num) {
1831   const float e_x2 = (float)(x2_sum / num);
1832   const float diff = e_x2 - mean * mean;
1833   const float dev = (diff > 0) ? sqrtf(diff) : 0;
1834   return dev;
1835 }
1836 
1837 // Feature used by the model to predict tx split: the mean and standard
1838 // deviation values of the block and sub-blocks.
get_mean_dev_features(const int16_t * data,int stride,int bw,int bh,float * feature)1839 static AOM_INLINE void get_mean_dev_features(const int16_t *data, int stride,
1840                                              int bw, int bh, float *feature) {
1841   const int16_t *const data_ptr = &data[0];
1842   const int subh = (bh >= bw) ? (bh >> 1) : bh;
1843   const int subw = (bw >= bh) ? (bw >> 1) : bw;
1844   const int num = bw * bh;
1845   const int sub_num = subw * subh;
1846   int feature_idx = 2;
1847   int total_x_sum = 0;
1848   int64_t total_x2_sum = 0;
1849   int blk_idx = 0;
1850   double mean2_sum = 0.0f;
1851   float dev_sum = 0.0f;
1852 
1853   for (int row = 0; row < bh; row += subh) {
1854     for (int col = 0; col < bw; col += subw) {
1855       int x_sum;
1856       int64_t x2_sum;
1857       // TODO(any): Write a SIMD version. Clear registers.
1858       aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
1859                           &x_sum, &x2_sum);
1860       total_x_sum += x_sum;
1861       total_x2_sum += x2_sum;
1862 
1863       aom_clear_system_state();
1864       const float mean = (float)x_sum / sub_num;
1865       const float dev = get_dev(mean, (double)x2_sum, sub_num);
1866       feature[feature_idx++] = mean;
1867       feature[feature_idx++] = dev;
1868       mean2_sum += (double)(mean * mean);
1869       dev_sum += dev;
1870       blk_idx++;
1871     }
1872   }
1873 
1874   const float lvl0_mean = (float)total_x_sum / num;
1875   feature[0] = lvl0_mean;
1876   feature[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
1877 
1878   if (blk_idx > 1) {
1879     // Deviation of means.
1880     feature[feature_idx++] = get_dev(lvl0_mean, mean2_sum, blk_idx);
1881     // Mean of deviations.
1882     feature[feature_idx++] = dev_sum / blk_idx;
1883   }
1884 }
1885 
ml_predict_tx_split(MACROBLOCK * x,BLOCK_SIZE bsize,int blk_row,int blk_col,TX_SIZE tx_size)1886 static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
1887                                int blk_col, TX_SIZE tx_size) {
1888   const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
1889   if (!nn_config) return -1;
1890 
1891   const int diff_stride = block_size_wide[bsize];
1892   const int16_t *diff =
1893       x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1894   const int bw = tx_size_wide[tx_size];
1895   const int bh = tx_size_high[tx_size];
1896   aom_clear_system_state();
1897 
1898   float features[64] = { 0.0f };
1899   get_mean_dev_features(diff, diff_stride, bw, bh, features);
1900 
1901   float score = 0.0f;
1902   av1_nn_predict(features, nn_config, 1, &score);
1903   aom_clear_system_state();
1904 
1905   int int_score = (int)(score * 10000);
1906   return clamp(int_score, -80000, 80000);
1907 }
1908 
// Determines which tx types may be searched for one transform block.
// Returns a bitmask with bit t set when tx type t may be evaluated.
// *allowed_txk_types receives either the single tx type the block is
// restricted to, or TX_TYPES when the mask may hold several types.
// For luma blocks that reach the pruning helpers below, txk_map may be
// rewritten with a suggested evaluation order.
// The mask is narrowed as follows:
//   - default-tx-type modes, and LOW_TXFM_RD for luma, force a single type;
//   - chroma reuses the tx type derived for luma;
//   - lossless, transforms above 32x32, single-type sets and the
//     use_*_dct_only options force DCT_DCT;
//   - otherwise luma types are pruned by frame statistics, RD estimation or
//     the 2D NN model, depending on speed features.
// An empty result falls back to DCT_DCT (luma) or the luma-derived type
// (chroma).
static INLINE uint16_t
get_tx_mask(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block,
            int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
            const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
            int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
  // TX_TYPES, only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;

  if ((!is_inter && x->use_default_intra_tx_type) ||
      (is_inter && x->use_default_inter_tx_type)) {
    txk_allowed =
        get_default_tx_type(0, xd, tx_size, cpi->is_screen_content_type);
  } else if (x->rd_model == LOW_TXFM_RD) {
    if (plane == 0) txk_allowed = DCT_DCT;
  }

  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
      tx_size, is_inter, cm->features.reduced_tx_set_used);

  TX_TYPE uv_tx_type = DCT_DCT;
  if (plane) {
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
    uv_tx_type = txk_allowed =
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
                        cm->features.reduced_tx_set_used);
  }
  // Filter-intra blocks use the intra direction mapped from the filter mode
  // when selecting the reduced intra tx set.
  PREDICTION_MODE intra_dir =
      mbmi->filter_intra_mode_info.use_filter_intra
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
          : mbmi->mode;
  uint16_t ext_tx_used_flag =
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset &&
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
          ? av1_reduced_intra_tx_used_flag[intra_dir]
          : av1_ext_tx_used_flag[tx_set_type];
  // Cases where only DCT_DCT may be used (0x0001 is the DCT_DCT-only set).
  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
      ext_tx_used_flag == 0x0001 ||
      (is_inter && cpi->oxcf.use_inter_dct_only) ||
      (!is_inter && cpi->oxcf.use_intra_dct_only)) {
    txk_allowed = DCT_DCT;
  }

  if (cpi->oxcf.enable_flip_idtx == 0) ext_tx_used_flag &= DCT_ADST_TX_MASK;

  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
  if (txk_allowed < TX_TYPES) {
    allowed_tx_mask = 1 << txk_allowed;
    allowed_tx_mask &= ext_tx_used_flag;
  } else if (fast_tx_search) {
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
    allowed_tx_mask &= ext_tx_used_flag;
  } else {
    assert(plane == 0);
    allowed_tx_mask = ext_tx_used_flag;
    int num_allowed = 0;
    const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
    const int *tx_type_probs =
        cpi->frame_probs.tx_type_probs[update_type][tx_size];
    int i;

    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
      // Drop tx types whose frame-level probability for this update type is
      // below thresh, but never the most probable allowed type.
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
                                            { 10, 17, 17, 10, 17, 17, 17 } };
      const int thresh =
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
                    [update_type];
      uint16_t prune = 0;
      int max_prob = -1;
      int max_idx = 0;
      for (i = 0; i < TX_TYPES; i++) {
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
          max_prob = tx_type_probs[i];
          max_idx = i;
        }
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
      }
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
      allowed_tx_mask &= (~prune);
    }
    for (i = 0; i < TX_TYPES; i++) {
      if (allowed_tx_mask & (1 << i)) num_allowed++;
    }
    assert(num_allowed > 0);

    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
      // RD-estimate based pruning: evaluate every candidate when there are
      // few (<= 7); otherwise use the cheaper separable 1D search and keep
      // about mf% of the candidates.
      int pf = prune_factors[x->prune_mode];
      int mf = mul_factors[x->prune_mode];
      if (num_allowed <= 7) {
        const uint16_t prune =
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
                           cm->features.reduced_tx_set_used);
        allowed_tx_mask &= (~prune);
      } else {
        const int num_sel = (num_allowed * mf + 50) / 100;
        const uint16_t prune = prune_txk_type_separ(
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
            txk_map, allowed_tx_mask, pf, txb_ctx,
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);

        allowed_tx_mask &= (~prune);
      }
    } else {
      assert(num_allowed > 0);
      int allowed_tx_count = (x->prune_mode == PRUNE_2D_AGGRESSIVE) ? 1 : 5;
      // Luma only (plane == 0) without fast_tx_search; the 2D NN model is
      // applied to inter blocks with more than allowed_tx_count candidates.
      if (x->prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
          num_allowed > allowed_tx_count) {
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
                    x->prune_mode, txk_map, &allowed_tx_mask);
      }
    }
  }

  // Need to have at least one transform type allowed.
  if (allowed_tx_mask == 0) {
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
    allowed_tx_mask = (1 << txk_allowed);
  }

  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
  *allowed_txk_types = txk_allowed;
  return allowed_tx_mask;
}
2039 
#if CONFIG_RD_DEBUG
// Debug-only bookkeeping of coefficient costs. Adds txb_coeff_cost to the
// per-plane total in rd_stats, clears every txb_coeff_cost_map unit the
// transform block covers, then records the cost at the block's top-left unit
// (blk_row, blk_col).
static INLINE void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         TX_SIZE tx_size, int blk_row,
                                         int blk_col, int txb_coeff_cost) {
  const int txb_h = tx_size_high_unit[tx_size];
  const int txb_w = tx_size_wide_unit[tx_size];
  // Validate the block's whole footprint before writing into the fixed-size
  // map, so an out-of-range position trips an assert instead of writing out
  // of bounds.
  assert(blk_row >= 0 && blk_row + txb_h <= TXB_COEFF_COST_MAP_SIZE);
  assert(blk_col >= 0 && blk_col + txb_w <= TXB_COEFF_COST_MAP_SIZE);

  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;

  for (int idy = 0; idy < txb_h; ++idy)
    for (int idx = 0; idx < txb_w; ++idx)
      rd_stats->txb_coeff_cost_map[plane][blk_row + idy][blk_col + idx] = 0;

  rd_stats->txb_coeff_cost_map[plane][blk_row][blk_col] = txb_coeff_cost;
}
#endif
2063 
cost_coeffs(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,const TX_TYPE tx_type,const TXB_CTX * const txb_ctx,int use_fast_coef_costing,int reduced_tx_set_used)2064 static INLINE int cost_coeffs(MACROBLOCK *x, int plane, int block,
2065                               TX_SIZE tx_size, const TX_TYPE tx_type,
2066                               const TXB_CTX *const txb_ctx,
2067                               int use_fast_coef_costing,
2068                               int reduced_tx_set_used) {
2069 #if TXCOEFF_COST_TIMER
2070   struct aom_usec_timer timer;
2071   aom_usec_timer_start(&timer);
2072 #endif
2073   (void)use_fast_coef_costing;
2074   const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
2075                                        txb_ctx, reduced_tx_set_used);
2076 #if TXCOEFF_COST_TIMER
2077   AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
2078   aom_usec_timer_mark(&timer);
2079   const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
2080   tmp_cm->txcoeff_cost_timer += elapsed_time;
2081   ++tmp_cm->txcoeff_cost_count;
2082 #endif
2083   return cost;
2084 }
2085 
// Search for the best transform type for a given transform block.
// This function can be used for both inter and intra, both luma and chroma.
//
// Every transform type allowed by get_tx_mask() goes through the same steps:
// forward transform, quantization (with av1_optimize_b() trellis when
// quant_param.use_optimize_b is set), rate costing and distortion measurement.
// The candidate with the lowest RDCOST wins. On return:
//   - *best_rd_stats holds the winner's rate/dist/sse, and skip = (eob == 0);
//   - x->plane[plane].eobs[block] and txb_entropy_ctx[block] describe the
//     winner;
//   - for plane 0, the transform type map is updated with the winning type;
//   - recon_intra() has been invoked with the winner;
//   - pd->dqcoeff points again at the buffer it held on entry.
// ref_best_rd is an external RD budget. It is passed to get_tx_mask() and
// used by the early-termination checks in the candidate loop.
// use_fast_coef_costing is forwarded to cost_coeffs(), which ignores it.
static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                           int block, int blk_row, int blk_col,
                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                           const TXB_CTX *const txb_ctx,
                           FAST_TX_SEARCH_MODE ftxs_mode,
                           int use_fast_coef_costing, int skip_trellis,
                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  struct macroblockd_plane *const pd = &xd->plane[plane];
  MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t best_rd = INT64_MAX;
  uint16_t best_eob = 0;
  TX_TYPE best_tx_type = DCT_DCT;
  int rate_cost = 0;
  // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
  // of the best tx_type
  DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
  tran_low_t *orig_dqcoeff = pd->dqcoeff;
  tran_low_t *best_dqcoeff = this_dqcoeff;
  // Position of this block in xd->tx_type_map. The map is only written for
  // plane 0 in this function (see the plane == 0 guards), so chroma uses 0.
  const int tx_type_map_idx =
      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
  av1_invalid_rd_stats(best_rd_stats);

  // Disable trellis when the segment's optimize setting does not enable it
  // for DRY_RUN_NORMAL.
  skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
                                   DRY_RUN_NORMAL);

  // Hashing based speed feature for intra block. If the hash of the residue
  // is found in the hash table, use the previous RD search results stored in
  // the table and terminate early.
  TXB_RD_INFO *intra_txb_rd_info = NULL;
  uint16_t cur_joint_ctx = 0;
  const int is_inter = is_inter_block(mbmi);
  // Only square luma transform blocks of intra blocks in intra-only frames
  // use the hash.
  const int use_intra_txb_hash =
      cpi->sf.tx_sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
      !is_inter && plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size];
  if (use_intra_txb_hash) {
    const int mi_row = xd->mi_row;
    const int mi_col = xd->mi_col;
    // The hash is only consulted when the block lies strictly inside the tile.
    const int within_border =
        mi_row >= xd->tile.mi_row_start &&
        (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
        mi_col >= xd->tile.mi_col_start &&
        (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
    if (within_border &&
        is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize,
                            tx_size, txb_ctx, &intra_txb_rd_info,
                            tx_type_map_idx, &cur_joint_ctx)) {
      // Hash hit: reuse the stored stats and transform type, reconstruct,
      // and return without searching.
      best_rd_stats->rate = intra_txb_rd_info->rate;
      best_rd_stats->dist = intra_txb_rd_info->dist;
      best_rd_stats->sse = intra_txb_rd_info->sse;
      best_rd_stats->skip = intra_txb_rd_info->eob == 0;
      x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
      x->plane[plane].txb_entropy_ctx[block] =
          intra_txb_rd_info->txb_entropy_ctx;
      best_eob = intra_txb_rd_info->eob;
      best_tx_type = intra_txb_rd_info->tx_type;
      skip_trellis |= !intra_txb_rd_info->perform_block_coeff_opt;
      update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
      // NOTE(review): the literal 1 here (0 at the end of this function)
      // presumably tells recon_intra() that results came from the hash and
      // must be recomputed; confirm against recon_intra().
      recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                  txb_ctx, skip_trellis, best_tx_type, 1, &rate_cost, best_eob);
      pd->dqcoeff = orig_dqcoeff;
      return;
    }
  }

  uint8_t best_txb_ctx = 0;
  // txk_allowed = TX_TYPES: >1 tx types are allowed
  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;
  // Search order of transform types; get_tx_mask() may reorder it.
  int txk_map[TX_TYPES] = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  };
  // Bit mask to indicate which transform types are allowed in the RD search.
  const uint16_t allowed_tx_mask =
      get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                  txb_ctx, ftxs_mode, ref_best_rd, &txk_allowed, txk_map);

  // Pixel-domain SSE and MSE (Q8) of the residual over this transform block.
  unsigned int block_mse_q8;
  int64_t block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
                                      txsize_to_bsize[tx_size], &block_mse_q8);
  assert(block_mse_q8 != UINT_MAX);
  // Normalize high-bitdepth error to the 8-bit scale.
  if (is_cur_buf_hbd(xd)) {
    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
  }
  // NOTE(review): presumably scales the pixel SSE to the same fixed-point
  // scale as the distortion returned by dist_block_tx_domain() and
  // dist_block_px_domain(). The two are used interchangeably below; confirm.
  block_sse *= 16;
  // Quantizer step size derived from the AC dequantizer.
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
  // Use mse / qstep^2 based threshold logic to take decision of R-D
  // optimization of coeffs. For smaller residuals, coeff optimization
  // would be helpful. For larger residuals, R-D optimization may not be
  // effective.
  // TODO(any): Experiment with variance and mean based thresholds
  const int perform_block_coeff_opt =
      ((uint64_t)block_mse_q8 <=
       (uint64_t)x->coeff_opt_dist_threshold * qstep * qstep);
  skip_trellis |= !perform_block_coeff_opt;

  // Flag to indicate if distortion should be calculated in transform domain or
  // not during iterating through transform type candidates.
  // Transform domain distortion is accurate for higher residuals.
  // TODO(any): Experiment with variance and mean based thresholds
  int use_transform_domain_distortion =
      (x->use_transform_domain_distortion > 0) &&
      (block_mse_q8 >= x->tx_domain_dist_threshold) &&
      // Any 64-pt transforms only preserves half the coefficients.
      // Therefore transform domain distortion is not valid for these
      // transform sizes.
      txsize_sqr_up_map[tx_size] != TX_64X64;
  // Flag to indicate if an extra calculation of distortion in the pixel domain
  // should be performed at the end, after the best transform type has been
  // decided.
  int calc_pixel_domain_distortion_final =
      x->use_transform_domain_distortion == 1 &&
      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
  // With a single candidate there is nothing to compare, so measure pixel
  // domain distortion directly instead of measuring it twice.
  if (calc_pixel_domain_distortion_final &&
      (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;

  const uint16_t *eobs_ptr = x->plane[plane].eobs;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, !skip_trellis,
                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                         : AV1_XFORM_QUANT_FP)
                               : AV1_XFORM_QUANT_FP,
                  cpi->oxcf.quant_b_adapt, &quant_param);

  // Iterate through all transform type candidates.
  for (int idx = 0; idx < TX_TYPES; ++idx) {
    const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
    if (!(allowed_tx_mask & (1 << tx_type))) continue;
    txfm_param.tx_type = tx_type;
    if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                        &quant_param);
    }
    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
    RD_STATS this_rd_stats;
    av1_invalid_rd_stats(&this_rd_stats);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    // Calculate rate cost of quantized coefficients.
    if (quant_param.use_optimize_b) {
      if (cpi->sf.rd_sf.optimize_b_precheck && best_rd < INT64_MAX &&
          eobs_ptr[block] >= 4) {
        // Calculate distortion quickly in transform domain.
        dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                             &this_rd_stats.sse);

        // Skip the (costly) trellis for this candidate if its distortion
        // cost alone, discounted by 1/8, already exceeds the best RD so far
        // or the external budget.
        const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd);
        const int64_t dist_cost_estimate =
            RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse));
        if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue;
      }
      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
                     cpi->sf.rd_sf.trellis_eob_fast, &rate_cost);
    } else {
      rate_cost =
          cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
                      use_fast_coef_costing, cm->features.reduced_tx_set_used);
    }

    // If rd cost based on coeff rate alone is already more than best_rd,
    // terminate early.
    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;

    // Calculate distortion.
    if (eobs_ptr[block] == 0) {
      // When eob is 0, pixel domain distortion is more efficient and accurate.
      this_rd_stats.dist = this_rd_stats.sse = block_sse;
    } else if (use_transform_domain_distortion) {
      dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                           &this_rd_stats.sse);
    } else {
      int64_t sse_diff = INT64_MAX;
      // high_energy threshold assumes that every pixel within a txfm block
      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
      // for 8 bit, then the threshold is scaled based on input bit depth.
      const int64_t high_energy_thresh =
          ((int64_t)128 * 128 * tx_size_2d[tx_size]) << ((xd->bd - 8) * 2);
      const int is_high_energy = (block_sse >= high_energy_thresh);
      if (tx_size == TX_64X64 || is_high_energy) {
        // Because 3 out 4 quadrants of transform coefficients are forced to
        // zero, the inverse transform has a tendency to overflow. sse_diff
        // is effectively the energy of those 3 quadrants, here we use it
        // to decide if we should do pixel domain distortion. If the energy
        // is mostly in first quadrant, then it is unlikely that we have
        // overflow issue in inverse transform.
        dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
                             &this_rd_stats.sse);
        sse_diff = block_sse - this_rd_stats.sse;
      }
      if (tx_size != TX_64X64 || !is_high_energy ||
          (sse_diff * 2) < this_rd_stats.sse) {
        const int64_t tx_domain_dist = this_rd_stats.dist;
        this_rd_stats.dist = dist_block_px_domain(
            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
        // For high energy blocks, occasionally, the pixel domain distortion
        // can be artificially low due to clamping at reconstruction stage
        // even when inverse transform output is hugely different from the
        // actual residue.
        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
          this_rd_stats.dist = tx_domain_dist;
      } else {
        // Transform-domain distortion plus the energy of the zeroed quadrants.
        assert(sse_diff < INT64_MAX);
        this_rd_stats.dist += sse_diff;
      }
      this_rd_stats.sse = block_sse;
    }

    this_rd_stats.rate = rate_cost;

    const int64_t rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);

    if (rd < best_rd) {
      best_rd = rd;
      *best_rd_stats = this_rd_stats;
      best_tx_type = tx_type;
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
      best_eob = x->plane[plane].eobs[block];
      // Swap dqcoeff buffers
      // (the winner's coefficients are parked in best_dqcoeff and the next
      // candidate writes into the other buffer).
      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
      best_dqcoeff = pd->dqcoeff;
      pd->dqcoeff = tmp_dqcoeff;
    }

#if CONFIG_COLLECT_RD_STATS == 1
    if (plane == 0) {
      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
                              plane_bsize, tx_size, tx_type, rd);
    }
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if COLLECT_TX_SIZE_DATA
    // Generate small sample to restrict output size.
    static unsigned int seed = 21743;
    if (lcg_rand16(&seed) % 200 == 0) {
      FILE *fp = NULL;

      // NOTE(review): `within_border` is declared inside the
      // `if (use_intra_txb_hash)` block above and is not in scope here, so
      // this section does not compile when COLLECT_TX_SIZE_DATA is enabled.
      if (within_border) {
        fp = fopen(av1_tx_size_data_output_file, "a");
      }

      if (fp) {
        // Transform info and RD
        const int txb_w = tx_size_wide[tx_size];
        const int txb_h = tx_size_high[tx_size];

        // Residue signal.
        const int diff_stride = block_size_wide[plane_bsize];
        struct macroblock_plane *const p = &x->plane[plane];
        const int16_t *src_diff =
            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];

        for (int r = 0; r < txb_h; ++r) {
          for (int c = 0; c < txb_w; ++c) {
            fprintf(fp, "%d,", src_diff[c]);
          }
          src_diff += diff_stride;
        }

        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
        fprintf(fp, "\n");
        fclose(fp);
      }
    }
#endif  // COLLECT_TX_SIZE_DATA

    // If the current best RD cost is much worse than the reference RD cost,
    // terminate early.
    if (cpi->sf.tx_sf.adaptive_txb_search_level) {
      if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
          ref_best_rd) {
        break;
      }
    }

    // Terminate transform type search if the block has been quantized to
    // all zero.
    if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
  }

  assert(best_rd != INT64_MAX);

  // Restore the winner's state (the loop may have left a rejected candidate's
  // eob / entropy context / tx type behind).
  best_rd_stats->skip = best_eob == 0;
  if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
  x->plane[plane].eobs[block] = best_eob;

  // Point dqcoeff to the quantized coefficients corresponding to the best
  // transform type, then we can skip transform and quantization, e.g. in the
  // final pixel domain distortion calculation and recon_intra().
  pd->dqcoeff = best_dqcoeff;

  if (calc_pixel_domain_distortion_final && best_eob) {
    best_rd_stats->dist = dist_block_px_domain(
        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
    best_rd_stats->sse = block_sse;
  }

  // Store the result in the intra hash slot obtained above, if any.
  if (intra_txb_rd_info != NULL) {
    intra_txb_rd_info->valid = 1;
    intra_txb_rd_info->entropy_context = cur_joint_ctx;
    intra_txb_rd_info->rate = best_rd_stats->rate;
    intra_txb_rd_info->dist = best_rd_stats->dist;
    intra_txb_rd_info->sse = best_rd_stats->sse;
    intra_txb_rd_info->eob = best_eob;
    intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
    intra_txb_rd_info->perform_block_coeff_opt = perform_block_coeff_opt;
    if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
  }

  // Intra mode needs decoded pixels such that the next transform block
  // can use them for prediction.
  recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
  pd->dqcoeff = orig_dqcoeff;
}
2413 
2414 // Pick transform type for a luma transform block of tx_size. Note this function
2415 // is used only for inter-predicted blocks.
tx_type_rd(const AV1_COMP * cpi,MACROBLOCK * x,TX_SIZE tx_size,int blk_row,int blk_col,int block,int plane_bsize,TXB_CTX * txb_ctx,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_rdcost,TXB_RD_INFO * rd_info_array)2416 static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2417                                   TX_SIZE tx_size, int blk_row, int blk_col,
2418                                   int block, int plane_bsize, TXB_CTX *txb_ctx,
2419                                   RD_STATS *rd_stats,
2420                                   FAST_TX_SEARCH_MODE ftxs_mode,
2421                                   int64_t ref_rdcost,
2422                                   TXB_RD_INFO *rd_info_array) {
2423   const struct macroblock_plane *const p = &x->plane[0];
2424   const uint16_t cur_joint_ctx =
2425       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
2426   MACROBLOCKD *xd = &x->e_mbd;
2427   assert(is_inter_block(xd->mi[0]));
2428   const int tx_type_map_idx = blk_row * xd->tx_type_map_stride + blk_col;
2429   // Look up RD and terminate early in case when we've already processed exactly
2430   // the same residue with exactly the same entropy context.
2431   if (rd_info_array != NULL && rd_info_array->valid &&
2432       rd_info_array->entropy_context == cur_joint_ctx) {
2433     xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
2434     const TX_TYPE ref_tx_type =
2435         av1_get_tx_type(&x->e_mbd, get_plane_type(0), blk_row, blk_col, tx_size,
2436                         cpi->common.features.reduced_tx_set_used);
2437     if (ref_tx_type == rd_info_array->tx_type) {
2438       rd_stats->rate += rd_info_array->rate;
2439       rd_stats->dist += rd_info_array->dist;
2440       rd_stats->sse += rd_info_array->sse;
2441       rd_stats->skip &= rd_info_array->eob == 0;
2442       p->eobs[block] = rd_info_array->eob;
2443       p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
2444       return;
2445     }
2446   }
2447 
2448   RD_STATS this_rd_stats;
2449   const int skip_trellis = 0;
2450   search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2451                  txb_ctx, ftxs_mode, 0, skip_trellis, ref_rdcost,
2452                  &this_rd_stats);
2453 
2454   av1_merge_rd_stats(rd_stats, &this_rd_stats);
2455 
2456   // Save RD results for possible reuse in future.
2457   if (rd_info_array != NULL) {
2458     rd_info_array->valid = 1;
2459     rd_info_array->entropy_context = cur_joint_ctx;
2460     rd_info_array->rate = this_rd_stats.rate;
2461     rd_info_array->dist = this_rd_stats.dist;
2462     rd_info_array->sse = this_rd_stats.sse;
2463     rd_info_array->eob = p->eobs[block];
2464     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
2465     rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
2466   }
2467 }
2468 
try_tx_block_no_split(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,const ENTROPY_CONTEXT * ta,const ENTROPY_CONTEXT * tl,int txfm_partition_ctx,RD_STATS * rd_stats,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node,TxCandidateInfo * no_split)2469 static AOM_INLINE void try_tx_block_no_split(
2470     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2471     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2472     const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2473     int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2474     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2475     TxCandidateInfo *no_split) {
2476   MACROBLOCKD *const xd = &x->e_mbd;
2477   MB_MODE_INFO *const mbmi = xd->mi[0];
2478   struct macroblock_plane *const p = &x->plane[0];
2479   const int bw = mi_size_wide[plane_bsize];
2480   const ENTROPY_CONTEXT *const pta = ta + blk_col;
2481   const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2482   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2483   TXB_CTX txb_ctx;
2484   get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2485   const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
2486                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2487   rd_stats->zero_rate = zero_blk_rate;
2488   const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2489   mbmi->inter_tx_size[index] = tx_size;
2490   tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2491              rd_stats, ftxs_mode, ref_best_rd,
2492              rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
2493   assert(rd_stats->rate < INT_MAX);
2494 
2495   const int pick_skip = !xd->lossless[mbmi->segment_id] &&
2496                         (rd_stats->skip == 1 ||
2497                          RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2498                              RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2499   if (pick_skip) {
2500 #if CONFIG_RD_DEBUG
2501     update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
2502                           zero_blk_rate - rd_stats->rate);
2503 #endif  // CONFIG_RD_DEBUG
2504     rd_stats->rate = zero_blk_rate;
2505     rd_stats->dist = rd_stats->sse;
2506     p->eobs[block] = 0;
2507     update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2508   }
2509   rd_stats->skip = pick_skip;
2510   set_blk_skip(x, 0, blk_row * bw + blk_col, pick_skip);
2511 
2512   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2513     rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0];
2514 
2515   no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2516   no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2517   no_split->tx_type =
2518       xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2519 }
2520 
try_tx_block_split(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,int txfm_partition_ctx,int64_t no_split_rd,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node,RD_STATS * split_rd_stats)2521 static AOM_INLINE void try_tx_block_split(
2522     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2523     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2524     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2525     int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2526     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2527     RD_STATS *split_rd_stats) {
2528   assert(tx_size < TX_SIZES_ALL);
2529   MACROBLOCKD *const xd = &x->e_mbd;
2530   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2531   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2532   const int txb_width = tx_size_wide_unit[tx_size];
2533   const int txb_height = tx_size_high_unit[tx_size];
2534   // Transform size after splitting current block.
2535   const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2536   const int sub_txb_width = tx_size_wide_unit[sub_txs];
2537   const int sub_txb_height = tx_size_high_unit[sub_txs];
2538   const int sub_step = sub_txb_width * sub_txb_height;
2539   const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2540   assert(nblks > 0);
2541   av1_init_rd_stats(split_rd_stats);
2542   split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];
2543 
2544   for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2545     for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2546       assert(blk_idx < 4);
2547       const int offsetr = blk_row + r;
2548       const int offsetc = blk_col + c;
2549       if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
2550 
2551       RD_STATS this_rd_stats;
2552       int this_cost_valid = 1;
2553       select_tx_block(
2554           cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
2555           tl, tx_above, tx_left, &this_rd_stats, no_split_rd / nblks,
2556           ref_best_rd - split_rd_stats->rdcost, &this_cost_valid, ftxs_mode,
2557           (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
2558       if (!this_cost_valid) {
2559         split_rd_stats->rdcost = INT64_MAX;
2560         return;
2561       }
2562       av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2563       split_rd_stats->rdcost =
2564           RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2565       if (split_rd_stats->rdcost > ref_best_rd) {
2566         split_rd_stats->rdcost = INT64_MAX;
2567         return;
2568       }
2569       block += sub_step;
2570     }
2571   }
2572 }
2573 
2574 // Search for the best transform partition(recursive)/type for a given
2575 // inter-predicted luma block. The obtained transform selection will be saved
2576 // in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
select_tx_block(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,RD_STATS * rd_stats,int64_t prev_level_rd,int64_t ref_best_rd,int * is_cost_valid,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node)2577 static AOM_INLINE void select_tx_block(
2578     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2579     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2580     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2581     RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2582     int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
2583     TXB_RD_INFO_NODE *rd_info_node) {
2584   assert(tx_size < TX_SIZES_ALL);
2585   av1_init_rd_stats(rd_stats);
2586   if (ref_best_rd < 0) {
2587     *is_cost_valid = 0;
2588     return;
2589   }
2590 
2591   MACROBLOCKD *const xd = &x->e_mbd;
2592   assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2593          blk_col < max_block_wide(xd, plane_bsize, 0));
2594   MB_MODE_INFO *const mbmi = xd->mi[0];
2595   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2596                                          mbmi->sb_type, tx_size);
2597   struct macroblock_plane *const p = &x->plane[0];
2598 
2599   const int try_no_split =
2600       cpi->oxcf.enable_tx64 || txsize_sqr_up_map[tx_size] != TX_64X64;
2601   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2602   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2603 
2604   // Try using current block as a single transform block without split.
2605   if (try_no_split) {
2606     try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2607                           plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2608                           ftxs_mode, rd_info_node, &no_split);
2609 
2610     // Speed features for early termination.
2611     const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2612     if (search_level) {
2613       if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2614         *is_cost_valid = 0;
2615         return;
2616       }
2617       if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2618         try_split = 0;
2619       }
2620     }
2621     if (cpi->sf.tx_sf.txb_split_cap) {
2622       if (p->eobs[block] == 0) try_split = 0;
2623     }
2624   }
2625 
2626   // ML based speed feature to skip searching for split transform blocks.
2627   if (x->e_mbd.bd == 8 && try_split &&
2628       !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2629     const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2630     if (threshold >= 0) {
2631       const int split_score =
2632           ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2633       if (split_score < -threshold) try_split = 0;
2634     }
2635   }
2636 
2637   RD_STATS split_rd_stats;
2638   split_rd_stats.rdcost = INT64_MAX;
2639   // Try splitting current block into smaller transform blocks.
2640   if (try_split) {
2641     try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2642                        plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2643                        AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2644                        rd_info_node, &split_rd_stats);
2645   }
2646 
2647   if (no_split.rd < split_rd_stats.rdcost) {
2648     ENTROPY_CONTEXT *pta = ta + blk_col;
2649     ENTROPY_CONTEXT *ptl = tl + blk_row;
2650     p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2651     av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2652     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2653                           tx_size);
2654     for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2655       for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2656         const int index =
2657             av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2658         mbmi->inter_tx_size[index] = tx_size;
2659       }
2660     }
2661     mbmi->tx_size = tx_size;
2662     update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2663     const int bw = mi_size_wide[plane_bsize];
2664     set_blk_skip(x, 0, blk_row * bw + blk_col, rd_stats->skip);
2665   } else {
2666     *rd_stats = split_rd_stats;
2667     if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2668   }
2669 }
2670 
choose_largest_tx_size(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)2671 static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
2672                                               MACROBLOCK *x, RD_STATS *rd_stats,
2673                                               int64_t ref_best_rd,
2674                                               BLOCK_SIZE bs) {
2675   MACROBLOCKD *const xd = &x->e_mbd;
2676   MB_MODE_INFO *const mbmi = xd->mi[0];
2677   mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode_search_type);
2678 
2679   // If tx64 is not enabled, we need to go down to the next available size
2680   if (!cpi->oxcf.enable_tx64) {
2681     static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2682       TX_4X4,    // 4x4 transform
2683       TX_8X8,    // 8x8 transform
2684       TX_16X16,  // 16x16 transform
2685       TX_32X32,  // 32x32 transform
2686       TX_32X32,  // 64x64 transform
2687       TX_4X8,    // 4x8 transform
2688       TX_8X4,    // 8x4 transform
2689       TX_8X16,   // 8x16 transform
2690       TX_16X8,   // 16x8 transform
2691       TX_16X32,  // 16x32 transform
2692       TX_32X16,  // 32x16 transform
2693       TX_32X32,  // 32x64 transform
2694       TX_32X32,  // 64x32 transform
2695       TX_4X16,   // 4x16 transform
2696       TX_16X4,   // 16x4 transform
2697       TX_8X32,   // 8x32 transform
2698       TX_32X8,   // 32x8 transform
2699       TX_16X32,  // 16x64 transform
2700       TX_32X16,  // 64x16 transform
2701     };
2702 
2703     mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
2704   }
2705 
2706   const int skip_ctx = av1_get_skip_context(xd);
2707   const int no_skip_flag_rate = x->skip_cost[skip_ctx][0];
2708   const int skip_flag_rate = x->skip_cost[skip_ctx][1];
2709   // Skip RDcost is used only for Inter blocks
2710   const int64_t skip_rd =
2711       is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_flag_rate, 0) : INT64_MAX;
2712   const int64_t no_skip_rd = RDCOST(x->rdmult, no_skip_flag_rate, 0);
2713   const int skip_trellis = 0;
2714   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2715                        AOMMIN(no_skip_rd, skip_rd), AOM_PLANE_Y, bs,
2716                        mbmi->tx_size, cpi->sf.rd_sf.use_fast_coef_costing,
2717                        FTXS_NONE, skip_trellis);
2718 }
2719 
choose_smallest_tx_size(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)2720 static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
2721                                                MACROBLOCK *x,
2722                                                RD_STATS *rd_stats,
2723                                                int64_t ref_best_rd,
2724                                                BLOCK_SIZE bs) {
2725   MACROBLOCKD *const xd = &x->e_mbd;
2726   MB_MODE_INFO *const mbmi = xd->mi[0];
2727 
2728   mbmi->tx_size = TX_4X4;
2729   // TODO(any) : Pass this_rd based on skip/non-skip cost
2730   const int skip_trellis = 0;
2731   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
2732                        cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE,
2733                        skip_trellis);
2734 }
2735 
2736 // Search for the best uniform transform size and type for current coding block.
choose_tx_size_type_from_rd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)2737 static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
2738                                                    MACROBLOCK *x,
2739                                                    RD_STATS *rd_stats,
2740                                                    int64_t ref_best_rd,
2741                                                    BLOCK_SIZE bs) {
2742   av1_invalid_rd_stats(rd_stats);
2743 
2744   MACROBLOCKD *const xd = &x->e_mbd;
2745   MB_MODE_INFO *const mbmi = xd->mi[0];
2746   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
2747   const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT;
2748   int start_tx;
2749   // The split depth can be at most MAX_TX_DEPTH, so the init_depth controls
2750   // how many times of splitting is allowed during the RD search.
2751   int init_depth;
2752 
2753   if (tx_select) {
2754     start_tx = max_rect_tx_size;
2755     init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
2756                                        is_inter_block(mbmi), &cpi->sf,
2757                                        x->tx_size_search_method);
2758   } else {
2759     const TX_SIZE chosen_tx_size =
2760         tx_size_from_tx_mode(bs, x->tx_mode_search_type);
2761     start_tx = chosen_tx_size;
2762     init_depth = MAX_TX_DEPTH;
2763   }
2764 
2765   const int skip_trellis = 0;
2766   uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
2767   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
2768   TX_SIZE best_tx_size = max_rect_tx_size;
2769   int64_t best_rd = INT64_MAX;
2770   const int num_blks = bsize_to_num_blk(bs);
2771   x->rd_model = FULL_TXFM_RD;
2772   int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
2773   for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
2774        depth++, tx_size = sub_tx_size_map[tx_size]) {
2775     if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[tx_size] == TX_64X64) {
2776       continue;
2777     }
2778 
2779     RD_STATS this_rd_stats;
2780     rd[depth] = av1_uniform_txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs,
2781                                      tx_size, FTXS_NONE, skip_trellis);
2782     if (rd[depth] < best_rd) {
2783       av1_copy_array(best_blk_skip, x->blk_skip, num_blks);
2784       av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
2785       best_tx_size = tx_size;
2786       best_rd = rd[depth];
2787       *rd_stats = this_rd_stats;
2788     }
2789     if (tx_size == TX_4X4) break;
2790     // If we are searching three depths, prune the smallest size depending
2791     // on rd results for the first two depths for low contrast blocks.
2792     if (depth > init_depth && depth != MAX_TX_DEPTH &&
2793         x->source_variance < 256) {
2794       if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
2795     }
2796   }
2797 
2798   if (rd_stats->rate != INT_MAX) {
2799     mbmi->tx_size = best_tx_size;
2800     av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
2801     av1_copy_array(x->blk_skip, best_blk_skip, num_blks);
2802   }
2803 }
2804 
2805 // Search for the best transform type for the given transform block in the
2806 // given plane/channel, and calculate the corresponding RD cost.
block_rd_txfm(int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,void * arg)2807 static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
2808                                      int blk_col, BLOCK_SIZE plane_bsize,
2809                                      TX_SIZE tx_size, void *arg) {
2810   struct rdcost_block_args *args = arg;
2811   if (args->exit_early) {
2812     args->incomplete_exit = 1;
2813     return;
2814   }
2815 
2816   MACROBLOCK *const x = args->x;
2817   MACROBLOCKD *const xd = &x->e_mbd;
2818   const int is_inter = is_inter_block(xd->mi[0]);
2819   const AV1_COMP *cpi = args->cpi;
2820   ENTROPY_CONTEXT *a = args->t_above + blk_col;
2821   ENTROPY_CONTEXT *l = args->t_left + blk_row;
2822   const AV1_COMMON *cm = &cpi->common;
2823   RD_STATS this_rd_stats;
2824   av1_init_rd_stats(&this_rd_stats);
2825 
2826   if (!is_inter) {
2827     av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
2828     av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
2829   }
2830 
2831   TXB_CTX txb_ctx;
2832   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
2833   search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2834                  &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
2835                  args->skip_trellis, args->best_rd - args->current_rd,
2836                  &this_rd_stats);
2837 
2838   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
2839     assert(!is_inter || plane_bsize < BLOCK_8X8);
2840     cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
2841   }
2842 
2843 #if CONFIG_RD_DEBUG
2844   update_txb_coeff_cost(&this_rd_stats, plane, tx_size, blk_row, blk_col,
2845                         this_rd_stats.rate);
2846 #endif  // CONFIG_RD_DEBUG
2847   av1_set_txb_context(x, plane, block, tx_size, a, l);
2848 
2849   const int blk_idx =
2850       blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
2851   if (plane == 0)
2852     set_blk_skip(x, plane, blk_idx, x->plane[plane].eobs[block] == 0);
2853   else
2854     set_blk_skip(x, plane, blk_idx, 0);
2855 
2856   int64_t rd;
2857   if (is_inter) {
2858     const int64_t no_skip_rd =
2859         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2860     const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
2861     rd = AOMMIN(no_skip_rd, skip_rd);
2862     this_rd_stats.skip &= !x->plane[plane].eobs[block];
2863   } else {
2864     // Signal non-skip for Intra blocks
2865     rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2866     this_rd_stats.skip = 0;
2867   }
2868 
2869   av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
2870 
2871   args->current_rd += rd;
2872   if (args->current_rd > args->best_rd) args->exit_early = 1;
2873 }
2874 
2875 // Search for the best transform type and return the transform coefficients RD
2876 // cost of current luma coding block with the given uniform transform size.
av1_uniform_txfm_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs,TX_SIZE tx_size,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis)2877 int64_t av1_uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
2878                              RD_STATS *rd_stats, int64_t ref_best_rd,
2879                              BLOCK_SIZE bs, TX_SIZE tx_size,
2880                              FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
2881   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
2882   MACROBLOCKD *const xd = &x->e_mbd;
2883   MB_MODE_INFO *const mbmi = xd->mi[0];
2884   const int is_inter = is_inter_block(mbmi);
2885   const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT &&
2886                         block_signals_txsize(mbmi->sb_type);
2887   int tx_size_rate = 0;
2888   if (tx_select) {
2889     const int ctx = txfm_partition_context(
2890         xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
2891     tx_size_rate = is_inter ? x->txfm_partition_cost[ctx][0]
2892                             : tx_size_cost(x, bs, tx_size);
2893   }
2894   const int skip_ctx = av1_get_skip_context(xd);
2895   const int no_skip_flag_rate = x->skip_cost[skip_ctx][0];
2896   const int skip_flag_rate = x->skip_cost[skip_ctx][1];
2897   const int64_t skip_rd =
2898       is_inter ? RDCOST(x->rdmult, skip_flag_rate, 0) : INT64_MAX;
2899   const int64_t no_this_rd =
2900       RDCOST(x->rdmult, no_skip_flag_rate + tx_size_rate, 0);
2901 
2902   mbmi->tx_size = tx_size;
2903   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2904                        AOMMIN(no_this_rd, skip_rd), AOM_PLANE_Y, bs, tx_size,
2905                        cpi->sf.rd_sf.use_fast_coef_costing, ftxs_mode,
2906                        skip_trellis);
2907   if (rd_stats->rate == INT_MAX) return INT64_MAX;
2908 
2909   int64_t rd;
2910   // rdstats->rate should include all the rate except skip/non-skip cost as the
2911   // same is accounted in the caller functions after rd evaluation of all
2912   // planes. However the decisions should be done after considering the
2913   // skip/non-skip header cost
2914   if (rd_stats->skip && is_inter) {
2915     rd = RDCOST(x->rdmult, skip_flag_rate, rd_stats->sse);
2916   } else {
2917     // Intra blocks are always signalled as non-skip
2918     rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_flag_rate + tx_size_rate,
2919                 rd_stats->dist);
2920     rd_stats->rate += tx_size_rate;
2921   }
2922   // Check if forcing the block to skip transform leads to smaller RD cost.
2923   if (is_inter && !rd_stats->skip && !xd->lossless[mbmi->segment_id]) {
2924     int64_t temp_skip_rd = RDCOST(x->rdmult, skip_flag_rate, rd_stats->sse);
2925     if (temp_skip_rd <= rd) {
2926       rd = temp_skip_rd;
2927       rd_stats->rate = 0;
2928       rd_stats->dist = rd_stats->sse;
2929       rd_stats->skip = 1;
2930     }
2931   }
2932 
2933   return rd;
2934 }
2935 
// Search for the best transform type for a luma inter-predicted block, given
// the transform block partitions.
// This function is used only when some speed features are enabled.
//
// Starting from tx_size, it recurses through sub_tx_size_map until tx_size
// equals the size stored in mbmi->inter_tx_size for the position.
// At such a leaf it:
//   - runs tx_type_rd();
//   - compares the result against coding the block as all-zero;
//   - updates blk_skip, the tx type map, and the entropy and txfm-partition
//     contexts.
// rd_stats receives the accumulated stats. Its rate includes the
// txfm-partition flag cost whenever that flag is coded (tx_size > TX_4X4 and
// depth < MAX_VARTX_DEPTH). If any sub-block fails, rd_stats is invalidated
// (rate == INT_MAX).
// NOTE(review): when (blk_row, blk_col) lies outside the visible block, the
// function returns without writing rd_stats. Callers in this file initialize
// rd_stats before the call.
static AOM_INLINE void tx_block_yrd(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
    ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
    TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
    RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(is_inter_block(mbmi));
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);

  // Transform blocks entirely outside the visible area are not coded.
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  // The transform size already decided for this position.
  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
      plane_bsize, blk_row, blk_col)];
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->sb_type, tx_size);

  av1_init_rd_stats(rd_stats);
  if (tx_size == plane_tx_size) {
    // Leaf: search the transform type at the decided size.
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx;
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);

    // Rate of signalling this transform block as all-zero.
    const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
                                  .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    rd_stats->zero_rate = zero_blk_rate;
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
               rd_stats, ftxs_mode, ref_best_rd, NULL);
    const int mi_width = mi_size_wide[plane_bsize];
    // Zero out the block when coding it is no cheaper than coding it as
    // all-zero, or when the search already chose skip. DCT_DCT is then
    // recorded as its tx type.
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
        rd_stats->skip == 1) {
      rd_stats->rate = zero_blk_rate;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip = 1;
      set_blk_skip(x, 0, blk_row * mi_width + blk_col, 1);
      x->plane[0].eobs[block] = 0;
      x->plane[0].txb_entropy_ctx[block] = 0;
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    } else {
      rd_stats->skip = 0;
      set_blk_skip(x, 0, blk_row * mi_width + blk_col, 0);
    }
    // Cost of the "no further split" partition flag, when it is coded.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->txfm_partition_cost[ctx][0];
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
  } else {
    // Split: recurse into each sub-transform block in raster order, shrinking
    // the remaining RD budget as the sub-blocks accumulate cost.
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
    const int txb_width = tx_size_wide_unit[sub_txs];
    const int txb_height = tx_size_high_unit[sub_txs];
    const int step = txb_height * txb_width;
    RD_STATS pn_rd_stats;
    int64_t this_rd = 0;
    assert(txb_width > 0 && txb_height > 0);

    for (int row = 0; row < tx_size_high_unit[tx_size]; row += txb_height) {
      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += txb_width) {
        const int offsetr = blk_row + row;
        const int offsetc = blk_col + col;
        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;

        av1_init_rd_stats(&pn_rd_stats);
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
        block += step;
      }
    }

    // Cost of the "split" partition flag, when it is coded.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->txfm_partition_cost[ctx][1];
  }
}
3025 
3026 // search for tx type with tx sizes already decided for a inter-predicted luma
3027 // partition block. It's used only when some speed features are enabled.
3028 // Return value 0: early termination triggered, no valid rd cost available;
3029 //              1: rd cost values are valid.
inter_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode)3030 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3031                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
3032                            int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3033   if (ref_best_rd < 0) {
3034     av1_invalid_rd_stats(rd_stats);
3035     return 0;
3036   }
3037 
3038   av1_init_rd_stats(rd_stats);
3039 
3040   MACROBLOCKD *const xd = &x->e_mbd;
3041   const struct macroblockd_plane *const pd = &xd->plane[0];
3042   const int mi_width = mi_size_wide[bsize];
3043   const int mi_height = mi_size_high[bsize];
3044   const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3045   const int bh = tx_size_high_unit[max_tx_size];
3046   const int bw = tx_size_wide_unit[max_tx_size];
3047   const int step = bw * bh;
3048   const int init_depth = get_search_init_depth(mi_width, mi_height, 1, &cpi->sf,
3049                                                x->tx_size_search_method);
3050   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3051   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3052   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3053   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3054   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3055   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3056   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3057 
3058   int64_t this_rd = 0;
3059   for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3060     for (int idx = 0; idx < mi_width; idx += bw) {
3061       RD_STATS pn_rd_stats;
3062       av1_init_rd_stats(&pn_rd_stats);
3063       tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3064                    ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3065                    &pn_rd_stats, ftxs_mode);
3066       if (pn_rd_stats.rate == INT_MAX) {
3067         av1_invalid_rd_stats(rd_stats);
3068         return 0;
3069       }
3070       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3071       this_rd +=
3072           AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3073                  RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3074       block += step;
3075     }
3076   }
3077 
3078   const int skip_ctx = av1_get_skip_context(xd);
3079   const int no_skip_flag_rate = x->skip_cost[skip_ctx][0];
3080   const int skip_flag_rate = x->skip_cost[skip_ctx][1];
3081   const int64_t skip_rd = RDCOST(x->rdmult, skip_flag_rate, rd_stats->sse);
3082   this_rd =
3083       RDCOST(x->rdmult, rd_stats->rate + no_skip_flag_rate, rd_stats->dist);
3084   if (skip_rd < this_rd) {
3085     this_rd = skip_rd;
3086     rd_stats->rate = 0;
3087     rd_stats->dist = rd_stats->sse;
3088     rd_stats->skip = 1;
3089   }
3090 
3091   const int is_cost_valid = this_rd > ref_best_rd;
3092   if (!is_cost_valid) {
3093     // reset cost value
3094     av1_invalid_rd_stats(rd_stats);
3095   }
3096   return is_cost_valid;
3097 }
3098 
3099 // Search for the best transform size and type for current inter-predicted
3100 // luma block with recursive transform block partitioning. The obtained
3101 // transform selection will be saved in xd->mi[0], the corresponding RD stats
3102 // will be saved in rd_stats. The returned value is the corresponding RD cost.
select_tx_size_and_type(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd,TXB_RD_INFO_NODE * rd_info_tree)3103 static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3104                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
3105                                        int64_t ref_best_rd,
3106                                        TXB_RD_INFO_NODE *rd_info_tree) {
3107   MACROBLOCKD *const xd = &x->e_mbd;
3108   assert(is_inter_block(xd->mi[0]));
3109   assert(bsize < BLOCK_SIZES_ALL);
3110   const int fast_tx_search = x->tx_size_search_method > USE_FULL_RD;
3111   int64_t rd_thresh = ref_best_rd;
3112   if (fast_tx_search && rd_thresh < INT64_MAX) {
3113     if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3114   }
3115   assert(rd_thresh > 0);
3116   const FAST_TX_SEARCH_MODE ftxs_mode =
3117       fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3118   const struct macroblockd_plane *const pd = &xd->plane[0];
3119   assert(bsize < BLOCK_SIZES_ALL);
3120   const int mi_width = mi_size_wide[bsize];
3121   const int mi_height = mi_size_high[bsize];
3122   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3123   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3124   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3125   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3126   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3127   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3128   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3129   const int init_depth = get_search_init_depth(mi_width, mi_height, 1, &cpi->sf,
3130                                                x->tx_size_search_method);
3131   const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3132   const int bh = tx_size_high_unit[max_tx_size];
3133   const int bw = tx_size_wide_unit[max_tx_size];
3134   const int step = bw * bh;
3135   const int skip_ctx = av1_get_skip_context(xd);
3136   const int no_skip_flag_cost = x->skip_cost[skip_ctx][0];
3137   const int skip_flag_cost = x->skip_cost[skip_ctx][1];
3138   int64_t skip_rd = RDCOST(x->rdmult, skip_flag_cost, 0);
3139   int64_t no_skip_rd = RDCOST(x->rdmult, no_skip_flag_cost, 0);
3140   int block = 0;
3141 
3142   av1_init_rd_stats(rd_stats);
3143   for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3144     for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3145       const int64_t best_rd_sofar =
3146           (rd_thresh == INT64_MAX)
3147               ? INT64_MAX
3148               : (rd_thresh - (AOMMIN(skip_rd, no_skip_rd)));
3149       int is_cost_valid = 1;
3150       RD_STATS pn_rd_stats;
3151       // Search for the best transform block size and type for the sub-block.
3152       select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3153                       ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3154                       best_rd_sofar, &is_cost_valid, ftxs_mode, rd_info_tree);
3155       if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3156         av1_invalid_rd_stats(rd_stats);
3157         return INT64_MAX;
3158       }
3159       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3160       skip_rd = RDCOST(x->rdmult, skip_flag_cost, rd_stats->sse);
3161       no_skip_rd =
3162           RDCOST(x->rdmult, rd_stats->rate + no_skip_flag_cost, rd_stats->dist);
3163       block += step;
3164       if (rd_info_tree != NULL) rd_info_tree += 1;
3165     }
3166   }
3167 
3168   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3169 
3170   rd_stats->skip = (skip_rd <= no_skip_rd);
3171 
3172   // If fast_tx_search is true, only DCT and 1D DCT were tested in
3173   // select_inter_block_yrd() above. Do a better search for tx type with
3174   // tx sizes already decided.
3175   if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3176     if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3177       return INT64_MAX;
3178   }
3179 
3180   int64_t final_rd;
3181   if (rd_stats->skip) {
3182     final_rd = RDCOST(x->rdmult, skip_flag_cost, rd_stats->sse);
3183   } else {
3184     final_rd =
3185         RDCOST(x->rdmult, rd_stats->rate + no_skip_flag_cost, rd_stats->dist);
3186     if (!xd->lossless[xd->mi[0]->segment_id]) {
3187       final_rd =
3188           AOMMIN(final_rd, RDCOST(x->rdmult, skip_flag_cost, rd_stats->sse));
3189     }
3190   }
3191 
3192   return final_rd;
3193 }
3194 
3195 // Return 1 to terminate transform search early. The decision is made based on
3196 // the comparison with the reference RD cost and the model-estimated RD cost.
model_based_tx_search_prune(const AV1_COMP * cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int64_t ref_best_rd)3197 static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
3198                                                   MACROBLOCK *x,
3199                                                   BLOCK_SIZE bsize,
3200                                                   int64_t ref_best_rd) {
3201   const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3202   assert(level >= 0 && level <= 2);
3203   int model_rate;
3204   int64_t model_dist;
3205   int model_skip;
3206   MACROBLOCKD *const xd = &x->e_mbd;
3207   model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3208       cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3209       NULL, NULL, NULL);
3210   if (model_skip) return 0;
3211   const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3212   // TODO(debargha, urvang): Improve the model and make the check below
3213   // tighter.
3214   static const int prune_factor_by8[] = { 3, 5 };
3215   const int factor = prune_factor_by8[level - 1];
3216   return ((model_rd * factor) >> 3) > ref_best_rd;
3217 }
3218 
3219 // Search for best transform size and type for luma inter blocks. The transform
3220 // block partitioning can be recursive resulting in non-uniform transform sizes.
3221 // The best transform size and type, if found, will be saved in the MB_MODE_INFO
3222 // structure, and the corresponding RD stats will be saved in rd_stats.
av1_pick_recursive_tx_size_type_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd)3223 void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3224                                          RD_STATS *rd_stats, BLOCK_SIZE bsize,
3225                                          int64_t ref_best_rd) {
3226   MACROBLOCKD *const xd = &x->e_mbd;
3227   assert(is_inter_block(xd->mi[0]));
3228 
3229   av1_invalid_rd_stats(rd_stats);
3230 
3231   // If modeled RD cost is a lot worse than the best so far, terminate early.
3232   if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
3233       ref_best_rd != INT64_MAX) {
3234     if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
3235   }
3236 
3237   // Hashing based speed feature. If the hash of the prediction residue block is
3238   // found in the hash table, use previous search results and terminate early.
3239   uint32_t hash = 0;
3240   MB_RD_RECORD *mb_rd_record = NULL;
3241   const int mi_row = x->e_mbd.mi_row;
3242   const int mi_col = x->e_mbd.mi_col;
3243   const int within_border =
3244       mi_row >= xd->tile.mi_row_start &&
3245       (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
3246       mi_col >= xd->tile.mi_col_start &&
3247       (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
3248   const int is_mb_rd_hash_enabled =
3249       (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
3250   const int n4 = bsize_to_num_blk(bsize);
3251   if (is_mb_rd_hash_enabled) {
3252     hash = get_block_residue_hash(x, bsize);
3253     mb_rd_record = &x->mb_rd_record;
3254     const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3255     if (match_index != -1) {
3256       MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
3257       fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
3258       return;
3259     }
3260   }
3261 
3262   // If we predict that skip is the optimal RD decision - set the respective
3263   // context and terminate early.
3264   int64_t dist;
3265   if (x->predict_skip_level &&
3266       predict_skip_flag(x, bsize, &dist,
3267                         cpi->common.features.reduced_tx_set_used)) {
3268     set_skip_flag(x, rd_stats, bsize, dist);
3269     // Save the RD search results into tx_rd_record.
3270     if (is_mb_rd_hash_enabled)
3271       save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3272     return;
3273   }
3274 #if CONFIG_SPEED_STATS
3275   ++x->tx_search_count;
3276 #endif  // CONFIG_SPEED_STATS
3277 
3278   // Pre-compute residue hashes (transform block level) and find existing or
3279   // add new RD records to store and reuse rate and distortion values to speed
3280   // up TX size/type search.
3281   TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
3282   int found_rd_info = 0;
3283   if (ref_best_rd != INT64_MAX && within_border &&
3284       cpi->sf.tx_sf.use_inter_txb_hash) {
3285     found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
3286   }
3287 
3288   const int64_t rd =
3289       select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd,
3290                               found_rd_info ? matched_rd_info : NULL);
3291 
3292   if (rd == INT64_MAX) {
3293     // We should always find at least one candidate unless ref_best_rd is less
3294     // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
3295     // might have failed to find something better)
3296     assert(ref_best_rd != INT64_MAX);
3297     av1_invalid_rd_stats(rd_stats);
3298     return;
3299   }
3300 
3301   // Save the RD search results into tx_rd_record.
3302   if (is_mb_rd_hash_enabled) {
3303     assert(mb_rd_record != NULL);
3304     save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3305   }
3306 }
3307 
3308 // Search for the best transform size and type for current coding block, with
3309 // the assumption that all the transform blocks have a uniform size (VP9 style).
3310 // The selected transform size and type will be saved in the MB_MODE_INFO
3311 // structure; the corresponding RD stats will be saved in rd_stats.
3312 // This function may be used for both intra and inter predicted blocks.
av1_pick_uniform_tx_size_type_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bs,int64_t ref_best_rd)3313 void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3314                                        RD_STATS *rd_stats, BLOCK_SIZE bs,
3315                                        int64_t ref_best_rd) {
3316   MACROBLOCKD *const xd = &x->e_mbd;
3317   MB_MODE_INFO *const mbmi = xd->mi[0];
3318   assert(bs == mbmi->sb_type);
3319   const int is_inter = is_inter_block(mbmi);
3320   const int mi_row = xd->mi_row;
3321   const int mi_col = xd->mi_col;
3322 
3323   av1_init_rd_stats(rd_stats);
3324 
3325   // Hashing based speed feature for inter blocks. If the hash of the residue
3326   // block is found in the table, use previously saved search results and
3327   // terminate early.
3328   uint32_t hash = 0;
3329   MB_RD_RECORD *mb_rd_record = NULL;
3330   const int num_blks = bsize_to_num_blk(bs);
3331   if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3332     const int within_border =
3333         mi_row >= xd->tile.mi_row_start &&
3334         (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3335         mi_col >= xd->tile.mi_col_start &&
3336         (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3337     if (within_border) {
3338       hash = get_block_residue_hash(x, bs);
3339       mb_rd_record = &x->mb_rd_record;
3340       const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3341       if (match_index != -1) {
3342         MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
3343         fetch_tx_rd_info(num_blks, tx_rd_info, rd_stats, x);
3344         return;
3345       }
3346     }
3347   }
3348 
3349   // If we predict that skip is the optimal RD decision - set the respective
3350   // context and terminate early.
3351   int64_t dist;
3352   if (x->predict_skip_level && is_inter && !xd->lossless[mbmi->segment_id] &&
3353       predict_skip_flag(x, bs, &dist,
3354                         cpi->common.features.reduced_tx_set_used)) {
3355     // Populate rdstats as per skip decision
3356     set_skip_flag(x, rd_stats, bs, dist);
3357     // Save the RD search results into tx_rd_record.
3358     if (mb_rd_record) {
3359       save_tx_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3360     }
3361     return;
3362   }
3363 
3364   if (xd->lossless[mbmi->segment_id]) {
3365     // Lossless mode can only pick the smallest (4x4) transform size.
3366     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3367   } else if (x->tx_size_search_method == USE_LARGESTALL) {
3368     choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3369   } else {
3370     choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3371   }
3372 
3373   // Save the RD search results into tx_rd_record for possible reuse in future.
3374   if (mb_rd_record) {
3375     save_tx_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3376   }
3377 }
3378 
3379 // Calculate the transform coefficient RD cost for the given chroma coding block
3380 // Return value 0: early termination triggered, no valid rd cost available;
3381 //              1: rd cost values are valid.
av1_txfm_uvrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd)3382 int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
3383                   BLOCK_SIZE bsize, int64_t ref_best_rd) {
3384   av1_init_rd_stats(rd_stats);
3385   if (ref_best_rd < 0) return 0;
3386   if (!x->e_mbd.is_chroma_ref) return 1;
3387 
3388   MACROBLOCKD *const xd = &x->e_mbd;
3389   MB_MODE_INFO *const mbmi = xd->mi[0];
3390   struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
3391   const int is_inter = is_inter_block(mbmi);
3392   int64_t this_rd = 0, skip_rd = 0;
3393   const BLOCK_SIZE plane_bsize =
3394       get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
3395 
3396   if (is_inter) {
3397     for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
3398       av1_subtract_plane(x, plane_bsize, plane);
3399   }
3400 
3401   const int skip_trellis = 0;
3402   const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
3403   int is_cost_valid = 1;
3404   for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
3405     RD_STATS this_rd_stats;
3406     int64_t chroma_ref_best_rd = ref_best_rd;
3407     // For inter blocks, refined ref_best_rd is used for early exit
3408     // For intra blocks, even though current rd crosses ref_best_rd, early
3409     // exit is not recommended as current rd is used for gating subsequent
3410     // modes as well (say, for angular modes)
3411     // TODO(any): Extend the early exit mechanism for intra modes as well
3412     if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
3413         chroma_ref_best_rd != INT64_MAX)
3414       chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_rd);
3415     av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
3416                          plane_bsize, uv_tx_size,
3417                          cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE,
3418                          skip_trellis);
3419     if (this_rd_stats.rate == INT_MAX) {
3420       is_cost_valid = 0;
3421       break;
3422     }
3423     av1_merge_rd_stats(rd_stats, &this_rd_stats);
3424     this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3425     skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
3426     if (AOMMIN(this_rd, skip_rd) > ref_best_rd) {
3427       is_cost_valid = 0;
3428       break;
3429     }
3430   }
3431 
3432   if (!is_cost_valid) {
3433     // reset cost value
3434     av1_invalid_rd_stats(rd_stats);
3435   }
3436 
3437   return is_cost_valid;
3438 }
3439 
3440 // Search for the best transform type and calculate the transform coefficients
3441 // RD cost of the current coding block with the specified (uniform) transform
3442 // size and channel. The RD results will be saved in rd_stats.
av1_txfm_rd_in_plane(MACROBLOCK * x,const AV1_COMP * cpi,RD_STATS * rd_stats,int64_t ref_best_rd,int64_t current_rd,int plane,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,int use_fast_coef_costing,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis)3443 void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3444                           RD_STATS *rd_stats, int64_t ref_best_rd,
3445                           int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
3446                           TX_SIZE tx_size, int use_fast_coef_costing,
3447                           FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
3448   assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
3449 
3450   if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[tx_size] == TX_64X64) {
3451     av1_invalid_rd_stats(rd_stats);
3452     return;
3453   }
3454 
3455   if (current_rd > ref_best_rd) {
3456     av1_invalid_rd_stats(rd_stats);
3457     return;
3458   }
3459 
3460   MACROBLOCKD *const xd = &x->e_mbd;
3461   const struct macroblockd_plane *const pd = &xd->plane[plane];
3462   struct rdcost_block_args args;
3463   av1_zero(args);
3464   args.x = x;
3465   args.cpi = cpi;
3466   args.best_rd = ref_best_rd;
3467   args.current_rd = current_rd;
3468   args.use_fast_coef_costing = use_fast_coef_costing;
3469   args.ftxs_mode = ftxs_mode;
3470   args.skip_trellis = skip_trellis;
3471   av1_init_rd_stats(&args.rd_stats);
3472 
3473   av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
3474   av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
3475                                          &args);
3476 
3477   MB_MODE_INFO *const mbmi = xd->mi[0];
3478   const int is_inter = is_inter_block(mbmi);
3479   const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3480 
3481   if (invalid_rd) {
3482     av1_invalid_rd_stats(rd_stats);
3483   } else {
3484     *rd_stats = args.rd_stats;
3485   }
3486 }
3487 
// Runs the complete transform search of an inter-predicted block (including
// IntraBC) whose prediction has already been generated:
//   1. subtracts the prediction from the source for the luma plane;
//   2. searches luma with av1_pick_recursive_tx_size_type_yrd (TX_MODE_SELECT
//      and not lossless) or av1_pick_uniform_tx_size_type_yrd (otherwise);
//   3. searches chroma with av1_txfm_uvrd when the frame has chroma planes;
//   4. decides whether signalling the block as skip is cheaper.
// Early terminations from each stage are handled along the way.
//
// Parameters:
//   bsize       - size of the current coding block.
//   rd_stats    - out: totals for the block, including mode_rate and the rate
//                 of the chosen skip flag.
//   rd_stats_y  - out: luma-only statistics.
//   rd_stats_uv - out: chroma-only statistics.
//   mode_rate   - rate already spent on signalling the prediction mode.
//   ref_best_rd - RD cost to beat; the search stops once it cannot be beaten.
// Returns 1 when the output statistics are valid and 0 on early termination.
// On the very first exit (header cost alone exceeds ref_best_rd) only
// rd_stats_y is written (invalidated); rd_stats and rd_stats_uv are untouched.
// mbmi->skip is updated by the final skip decision.
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const int skip_ctx = av1_get_skip_context(xd);
  // Rate of signalling skip = 0 (index 0) or skip = 1 (index 1) in this
  // context.
  const int skip_flag_cost[2] = { x->skip_cost[skip_ctx][0],
                                  x->skip_cost[skip_ctx][1] };
  const int64_t min_header_rate =
      mode_rate + AOMMIN(skip_flag_cost[0], skip_flag_cost[1]);
  // Account for minimum skip and non_skip rd.
  // Eventually either one of them will be added to mode_rate
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  const AV1_COMMON *cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  // RD budget left for coefficient coding once the mode rate is paid.
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // Luma: compute the prediction residual, then search transform size/type.
  av1_subtract_plane(x, bsize, 0);
  if (x->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
#if CONFIG_COLLECT_RD_STATS == 2
    // NOTE(review): `tile_data` is neither a parameter nor a local of this
    // function, so this debug-only call does not appear to compile when
    // CONFIG_COLLECT_RD_STATS == 2. TODO confirm and plumb the tile data in.
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    // The uniform search yields one tx_size for the whole block. Mirror it
    // into inter_tx_size, and copy the luma skip decision into every plane-0
    // blk_skip entry (presumably one entry per 4x4 unit; TODO confirm).
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    for (int i = 0; i < xd->height * xd->width; ++i)
      set_blk_skip(x, 0, i, rd_stats_y->skip);
  }

  // The luma search produced no valid result.
  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Cheapest way to finish the block with luma alone. Either code it (non-skip
  // flag, rd_stats->rate already includes mode_rate), or signal skip and take
  // the SSE as distortion.
  const int64_t non_skip_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + skip_flag_cost[0], rd_stats->dist);
  const int64_t skip_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_flag_cost[1], rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_rdcosty, skip_rdcosty);
  if (min_rdcosty > ref_best_rd) {
    const int64_t tokenonly_rdy =
        AOMMIN(RDCOST(x->rdmult, rd_stats_y->rate, rd_stats_y->dist),
               RDCOST(x->rdmult, 0, rd_stats_y->sse));
    // Invalidate rd_stats_y to skip the rest of the motion modes search
    // when even the luma token cost, reduced by
    // (tokenonly_rdy >> prune_motion_mode_level), exceeds rd_thresh.
    if (tokenonly_rdy -
            (tokenonly_rdy >> cpi->sf.inter_sf.prune_motion_mode_level) >
        rd_thresh) {
      av1_invalid_rd_stats(rd_stats_y);
    }
    return 0;
  }

  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    int64_t ref_best_chroma_rd = ref_best_rd;
    // Chroma budget: what remains of ref_best_rd after the cheapest
    // mode + luma option computed above.
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd =
          (ref_best_chroma_rd - AOMMIN(non_skip_rdcosty, skip_rdcosty));
    }
    const int is_cost_valid_uv =
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
    if (!is_cost_valid_uv) return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  // Start from the skip flag accumulated in rd_stats by the searches. For
  // non-lossless blocks, also switch to skip when signalling skip is no more
  // expensive than coding the residual. mode_rate is left out of both sides
  // of that comparison.
  int choose_skip = rd_stats->skip;
  if (!choose_skip && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_flag_cost[0],
        rd_stats->dist);
    const int64_t rdcost_skip =
        RDCOST(x->rdmult, skip_flag_cost[1], rd_stats->sse);
    if (rdcost_no_skip >= rdcost_skip) choose_skip = 1;
  }
  if (choose_skip) {
    // Skip: no coefficient rate, and the distortion becomes the SSE.
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->rate = mode_rate + skip_flag_cost[1];
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    mbmi->skip = 1;
    // The final skip cost is re-checked against ref_best_rd only when the
    // searches themselves reported skip.
    if (rd_stats->skip) {
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) return 0;
    }
  } else {
    rd_stats->rate += skip_flag_cost[0];
    mbmi->skip = 0;
  }

  return 1;
}
3603