1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/encoder/tune_vmaf.h"
13 
14 #include "aom_dsp/psnr.h"
15 #include "aom_dsp/vmaf.h"
16 #include "aom_ports/system_state.h"
17 #include "av1/encoder/extend.h"
18 #include "av1/encoder/rdopt.h"
19 
// Reference VMAF score used as the baseline throughout this file:
// cal_approx_vmaf() subtracts it from a measured score, and
// av1_set_mb_vmaf_rdmult_scaling() measures how far each block's score falls
// below it.
// NOTE(review): presumably the score VMAF reports for an undistorted frame --
// confirm where 97.42773 comes from.
static const double kBaselineVmaf = 97.42773;
21 
22 // TODO(sdeng): Add the SIMD implementation.
highbd_unsharp_rect(const uint16_t * source,int source_stride,const uint16_t * blurred,int blurred_stride,uint16_t * dst,int dst_stride,int w,int h,double amount,int bit_depth)23 static AOM_INLINE void highbd_unsharp_rect(const uint16_t *source,
24                                            int source_stride,
25                                            const uint16_t *blurred,
26                                            int blurred_stride, uint16_t *dst,
27                                            int dst_stride, int w, int h,
28                                            double amount, int bit_depth) {
29   const int max_value = (1 << bit_depth) - 1;
30   for (int i = 0; i < h; ++i) {
31     for (int j = 0; j < w; ++j) {
32       const double val =
33           (double)source[j] + amount * ((double)source[j] - (double)blurred[j]);
34       dst[j] = (uint16_t)clamp((int)(val + 0.5), 0, max_value);
35     }
36     source += source_stride;
37     blurred += blurred_stride;
38     dst += dst_stride;
39   }
40 }
41 
unsharp_rect(const uint8_t * source,int source_stride,const uint8_t * blurred,int blurred_stride,uint8_t * dst,int dst_stride,int w,int h,double amount)42 static AOM_INLINE void unsharp_rect(const uint8_t *source, int source_stride,
43                                     const uint8_t *blurred, int blurred_stride,
44                                     uint8_t *dst, int dst_stride, int w, int h,
45                                     double amount) {
46   for (int i = 0; i < h; ++i) {
47     for (int j = 0; j < w; ++j) {
48       const double val =
49           (double)source[j] + amount * ((double)source[j] - (double)blurred[j]);
50       dst[j] = (uint8_t)clamp((int)(val + 0.5), 0, 255);
51     }
52     source += source_stride;
53     blurred += blurred_stride;
54     dst += dst_stride;
55   }
56 }
57 
unsharp(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG * source,const YV12_BUFFER_CONFIG * blurred,const YV12_BUFFER_CONFIG * dst,double amount)58 static AOM_INLINE void unsharp(const AV1_COMP *const cpi,
59                                const YV12_BUFFER_CONFIG *source,
60                                const YV12_BUFFER_CONFIG *blurred,
61                                const YV12_BUFFER_CONFIG *dst, double amount) {
62   const int bit_depth = cpi->td.mb.e_mbd.bd;
63   if (bit_depth > 8) {
64     highbd_unsharp_rect(CONVERT_TO_SHORTPTR(source->y_buffer), source->y_stride,
65                         CONVERT_TO_SHORTPTR(blurred->y_buffer),
66                         blurred->y_stride, CONVERT_TO_SHORTPTR(dst->y_buffer),
67                         dst->y_stride, source->y_width, source->y_height,
68                         amount, bit_depth);
69   } else {
70     unsharp_rect(source->y_buffer, source->y_stride, blurred->y_buffer,
71                  blurred->y_stride, dst->y_buffer, dst->y_stride,
72                  source->y_width, source->y_height, amount);
73   }
74 }
75 
// 8-tap Gaussian convolution filter with sigma = 1.0. The taps sum to 128 and
// every coefficient must be even.
DECLARE_ALIGNED(16, static const int16_t, gauss_filter[8]) = { 0,  8, 30, 52,
                                                               30, 8, 0,  0 };
gaussian_blur(const int bit_depth,const YV12_BUFFER_CONFIG * source,const YV12_BUFFER_CONFIG * dst)80 static AOM_INLINE void gaussian_blur(const int bit_depth,
81                                      const YV12_BUFFER_CONFIG *source,
82                                      const YV12_BUFFER_CONFIG *dst) {
83   const int block_size = BLOCK_128X128;
84   const int block_w = mi_size_wide[block_size] * 4;
85   const int block_h = mi_size_high[block_size] * 4;
86   const int num_cols = (source->y_width + block_w - 1) / block_w;
87   const int num_rows = (source->y_height + block_h - 1) / block_h;
88   int row, col;
89 
90   ConvolveParams conv_params = get_conv_params(0, 0, bit_depth);
91   InterpFilterParams filter = { .filter_ptr = gauss_filter,
92                                 .taps = 8,
93                                 .subpel_shifts = 0,
94                                 .interp_filter = EIGHTTAP_REGULAR };
95 
96   for (row = 0; row < num_rows; ++row) {
97     for (col = 0; col < num_cols; ++col) {
98       const int row_offset_y = row * block_h;
99       const int col_offset_y = col * block_w;
100 
101       uint8_t *src_buf =
102           source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
103       uint8_t *dst_buf =
104           dst->y_buffer + row_offset_y * dst->y_stride + col_offset_y;
105 
106       if (bit_depth > 8) {
107         av1_highbd_convolve_2d_sr(
108             CONVERT_TO_SHORTPTR(src_buf), source->y_stride,
109             CONVERT_TO_SHORTPTR(dst_buf), dst->y_stride, block_w, block_h,
110             &filter, &filter, 0, 0, &conv_params, bit_depth);
111       } else {
112         av1_convolve_2d_sr(src_buf, source->y_stride, dst_buf, dst->y_stride,
113                            block_w, block_h, &filter, &filter, 0, 0,
114                            &conv_params);
115       }
116     }
117   }
118 }
119 
frame_average_variance(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG * const frame)120 static double frame_average_variance(const AV1_COMP *const cpi,
121                                      const YV12_BUFFER_CONFIG *const frame) {
122   const uint8_t *const y_buffer = frame->y_buffer;
123   const int y_stride = frame->y_stride;
124   const BLOCK_SIZE block_size = BLOCK_64X64;
125 
126   const int block_w = mi_size_wide[block_size] * 4;
127   const int block_h = mi_size_high[block_size] * 4;
128   int row, col;
129   const int bit_depth = cpi->td.mb.e_mbd.bd;
130   double var = 0.0, var_count = 0.0;
131 
132   // Loop through each block.
133   for (row = 0; row < frame->y_height / block_h; ++row) {
134     for (col = 0; col < frame->y_width / block_w; ++col) {
135       struct buf_2d buf;
136       const int row_offset_y = row * block_h;
137       const int col_offset_y = col * block_w;
138 
139       buf.buf = (uint8_t *)y_buffer + row_offset_y * y_stride + col_offset_y;
140       buf.stride = y_stride;
141 
142       if (bit_depth > 8) {
143         var += av1_high_get_sby_perpixel_variance(cpi, &buf, block_size,
144                                                   bit_depth);
145       } else {
146         var += av1_get_sby_perpixel_variance(cpi, &buf, block_size);
147       }
148       var_count += 1.0;
149     }
150   }
151   var /= var_count;
152   return var;
153 }
154 
// Scores a sharpened frame: computes the VMAF of 'sharpened' against 'source'
// with the configured model, takes its difference from kBaselineVmaf, and
// scales that by source_variance / (average variance of 'sharpened').
static double cal_approx_vmaf(const AV1_COMP *const cpi, double source_variance,
                              YV12_BUFFER_CONFIG *const source,
                              YV12_BUFFER_CONFIG *const sharpened) {
  double vmaf_score;
  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, sharpened,
                cpi->td.mb.e_mbd.bd, &vmaf_score);

  const double sharpened_variance = frame_average_variance(cpi, sharpened);
  const double vmaf_delta = vmaf_score - kBaselineVmaf;
  return source_variance / sharpened_variance * vmaf_delta;
}
165 
// Walks the unsharp amount away from 'unsharp_amount_start' in increments of
// 'step_size' (which may be negative) for as long as the approximate VMAF
// score keeps improving on the previous one.
//   best_vmaf          score associated with the starting amount.
//   baseline_variance  variance of 'source', forwarded to cal_approx_vmaf().
//   max_loop_count     upper bound on the number of amounts evaluated.
//   max_amount         upper bound on the amount; the lower bound is 0.
// 'sharpened' is scratch space that is overwritten at every step.
// Returns the last amount whose score improved on its predecessor, clamped to
// [0, max_amount].
static double find_best_frame_unsharp_amount_loop(
    const AV1_COMP *const cpi, YV12_BUFFER_CONFIG *const source,
    YV12_BUFFER_CONFIG *const blurred, YV12_BUFFER_CONFIG *const sharpened,
    double best_vmaf, const double baseline_variance,
    const double unsharp_amount_start, const double step_size,
    const int max_loop_count, const double max_amount) {
  const double min_amount = 0.0;
  double amount = unsharp_amount_start;
  double current_vmaf = best_vmaf;
  int evaluated = 0;

  for (;;) {
    best_vmaf = current_vmaf;
    amount += step_size;
    // Stepped outside the allowed range: the undo below restores the amount.
    if (amount < min_amount || amount > max_amount) break;

    unsharp(cpi, source, blurred, sharpened, amount);
    current_vmaf = cal_approx_vmaf(cpi, baseline_variance, source, sharpened);
    ++evaluated;

    if (!(current_vmaf > best_vmaf && evaluated < max_loop_count)) break;
  }

  // The final step did not improve the score, so undo it.
  if (!(current_vmaf > best_vmaf)) amount = amount - step_size;
  return AOMMIN(max_amount, AOMMAX(amount, min_amount));
}
189 
// Searches for the unsharp amount to use for 'source', given its blurred
// counterpart 'blurred'.
// When 'unsharp_amount_start' is no larger than 'step_size' the search runs
// upwards from 0. Otherwise the amounts (start - step_size) and start are
// both scored: if the scores differ by less than 0.01 the lower amount is
// returned, else the search continues from the better of the two in its
// direction (downwards from the lower amount, upwards from start).
// A temporary frame of the same luma size as 'source' is allocated for the
// sharpened candidates and freed before returning.
static double find_best_frame_unsharp_amount(const AV1_COMP *const cpi,
                                             YV12_BUFFER_CONFIG *const source,
                                             YV12_BUFFER_CONFIG *const blurred,
                                             const double unsharp_amount_start,
                                             const double step_size,
                                             const int max_loop_count,
                                             const double max_filter_amount) {
  const AV1_COMMON *const cm = &cpi->common;

  YV12_BUFFER_CONFIG sharpened;
  memset(&sharpened, 0, sizeof(sharpened));
  aom_alloc_frame_buffer(&sharpened, source->y_width, source->y_height, 1, 1,
                         cm->seq_params.use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment);

  const double baseline_variance = frame_average_variance(cpi, source);
  double best_amount;

  if (unsharp_amount_start <= step_size) {
    best_amount = find_best_frame_unsharp_amount_loop(
        cpi, source, blurred, &sharpened, 0.0, baseline_variance, 0.0,
        step_size, max_loop_count, max_filter_amount);
  } else {
    const double lower = unsharp_amount_start - step_size;
    const double upper = unsharp_amount_start;

    // 'sharpened' is shared, so each candidate is scored right after it is
    // produced.
    unsharp(cpi, source, blurred, &sharpened, lower);
    const double lower_score =
        cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);
    unsharp(cpi, source, blurred, &sharpened, upper);
    const double upper_score =
        cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);

    if (fabs(lower_score - upper_score) < 0.01) {
      best_amount = lower;
    } else {
      const int search_down = lower_score > upper_score;
      best_amount = find_best_frame_unsharp_amount_loop(
          cpi, source, blurred, &sharpened,
          search_down ? lower_score : upper_score, baseline_variance,
          search_down ? lower : upper, search_down ? -step_size : step_size,
          max_loop_count, max_filter_amount);
    }
  }

  aom_free_frame_buffer(&sharpened);
  return best_amount;
}
236 
// Frame-level VMAF preprocessing: sharpens the luma plane of 'source' in
// place.
// Steps: blur a border-extended copy of the frame, search for the unsharp
// amount (seeded with cpi->last_frame_unsharp_amount, step 0.05, at most 20
// iterations, capped at 1.01), store the result back in
// cpi->last_frame_unsharp_amount, then apply it to 'source'.
void av1_vmaf_frame_preprocessing(AV1_COMP *const cpi,
                                  YV12_BUFFER_CONFIG *const source) {
  aom_clear_system_state();
  const AV1_COMMON *const cm = &cpi->common;
  const int bd = cpi->td.mb.e_mbd.bd;
  const int frame_w = source->y_width;
  const int frame_h = source->y_height;

  // Border-extended copy of the input; only needed as the blur input.
  YV12_BUFFER_CONFIG extended_src;
  memset(&extended_src, 0, sizeof(extended_src));
  aom_alloc_frame_buffer(&extended_src, frame_w, frame_h, 1, 1,
                         cm->seq_params.use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment);

  YV12_BUFFER_CONFIG blurred;
  memset(&blurred, 0, sizeof(blurred));
  aom_alloc_frame_buffer(&blurred, frame_w, frame_h, 1, 1,
                         cm->seq_params.use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment);

  av1_copy_and_extend_frame(source, &extended_src);
  gaussian_blur(bd, &extended_src, &blurred);
  aom_free_frame_buffer(&extended_src);

  const double amount = find_best_frame_unsharp_amount(
      cpi, source, &blurred, cpi->last_frame_unsharp_amount, 0.05, 20, 1.01);
  cpi->last_frame_unsharp_amount = amount;

  // 'source' is both input and output here.
  unsharp(cpi, source, &blurred, source, amount);
  aom_free_frame_buffer(&blurred);
  aom_clear_system_state();
}
267 
// Block-level VMAF preprocessing: sharpens the luma plane of 'source' in
// place, using a separate unsharp amount for every BLOCK_64X64 block.
// Steps:
//   1. Blur a border-extended copy of the frame.
//   2. Search for a frame-level amount (seeded with
//      cpi->last_frame_unsharp_amount) and store it back in
//      cpi->last_frame_unsharp_amount.
//   3. For each block, copy the block (zero-padded to full block size at the
//      frame edges) and its blurred version into scratch frames and search
//      for a per-block amount, seeded with the frame-level amount.
//   4. Apply each block's amount to the corresponding rectangle of 'source'.
// If the per-block amount table cannot be allocated, the function releases
// what it holds and returns with 'source' left unsharpened.
void av1_vmaf_blk_preprocessing(AV1_COMP *const cpi,
                                YV12_BUFFER_CONFIG *const source) {
  aom_clear_system_state();
  const AV1_COMMON *const cm = &cpi->common;
  const int width = source->y_width;
  const int height = source->y_height;
  const int bit_depth = cpi->td.mb.e_mbd.bd;

  YV12_BUFFER_CONFIG source_extended, blurred;
  memset(&blurred, 0, sizeof(blurred));
  memset(&source_extended, 0, sizeof(source_extended));
  aom_alloc_frame_buffer(
      &blurred, width, height, 1, 1, cm->seq_params.use_highbitdepth,
      cpi->oxcf.border_in_pixels, cm->features.byte_alignment);
  aom_alloc_frame_buffer(
      &source_extended, width, height, 1, 1, cm->seq_params.use_highbitdepth,
      cpi->oxcf.border_in_pixels, cm->features.byte_alignment);

  av1_copy_and_extend_frame(source, &source_extended);
  gaussian_blur(bit_depth, &source_extended, &blurred);
  aom_free_frame_buffer(&source_extended);

  const double best_frame_unsharp_amount = find_best_frame_unsharp_amount(
      cpi, source, &blurred, cpi->last_frame_unsharp_amount, 0.05, 20, 1.01);
  cpi->last_frame_unsharp_amount = best_frame_unsharp_amount;

  const int block_size = BLOCK_64X64;
  const int block_w = mi_size_wide[block_size] * 4;
  const int block_h = mi_size_high[block_size] * 4;
  const int num_cols = (source->y_width + block_w - 1) / block_w;
  const int num_rows = (source->y_height + block_h - 1) / block_h;

  double *best_unsharp_amounts = NULL;
  const size_t amounts_bytes =
      sizeof(*best_unsharp_amounts) * num_cols * num_rows;
  best_unsharp_amounts = aom_malloc(amounts_bytes);
  if (best_unsharp_amounts == NULL) {
    // Without the table there is nowhere to keep the per-block results;
    // release the blurred frame and leave 'source' as it is.
    aom_free_frame_buffer(&blurred);
    aom_clear_system_state();
    return;
  }
  memset(best_unsharp_amounts, 0, amounts_bytes);

  // Scratch frames holding one block of the source and of the blurred frame.
  YV12_BUFFER_CONFIG source_block, blurred_block;
  memset(&source_block, 0, sizeof(source_block));
  memset(&blurred_block, 0, sizeof(blurred_block));
  aom_alloc_frame_buffer(
      &source_block, block_w, block_h, 1, 1, cm->seq_params.use_highbitdepth,
      cpi->oxcf.border_in_pixels, cm->features.byte_alignment);
  aom_alloc_frame_buffer(
      &blurred_block, block_w, block_h, 1, 1, cm->seq_params.use_highbitdepth,
      cpi->oxcf.border_in_pixels, cm->features.byte_alignment);

  // Find the best unsharp amount for every block.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int row_offset_y = row * block_h;
      const int col_offset_y = col * block_w;
      // Part of the block that actually lies inside the frame.
      const int block_width = AOMMIN(width - col_offset_y, block_w);
      const int block_height = AOMMIN(height - row_offset_y, block_h);
      const int index = col + row * num_cols;

      if (bit_depth > 8) {
        uint16_t *frame_src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
                                  row_offset_y * source->y_stride +
                                  col_offset_y;
        uint16_t *frame_blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
                                      row_offset_y * blurred.y_stride +
                                      col_offset_y;
        uint16_t *blurred_dst = CONVERT_TO_SHORTPTR(blurred_block.y_buffer);
        uint16_t *src_dst = CONVERT_TO_SHORTPTR(source_block.y_buffer);

        // Copy the block out of the frames; samples outside the frame are
        // set to zero.
        for (int i = 0; i < block_h; ++i) {
          for (int j = 0; j < block_w; ++j) {
            if (i >= block_height || j >= block_width) {
              src_dst[j] = 0;
              blurred_dst[j] = 0;
            } else {
              src_dst[j] = frame_src_buf[j];
              blurred_dst[j] = frame_blurred_buf[j];
            }
          }
          frame_src_buf += source->y_stride;
          frame_blurred_buf += blurred.y_stride;
          src_dst += source_block.y_stride;
          blurred_dst += blurred_block.y_stride;
        }
      } else {
        uint8_t *frame_src_buf =
            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
        uint8_t *frame_blurred_buf =
            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
        uint8_t *blurred_dst = blurred_block.y_buffer;
        uint8_t *src_dst = source_block.y_buffer;

        // Copy the block out of the frames; samples outside the frame are
        // set to zero.
        for (int i = 0; i < block_h; ++i) {
          for (int j = 0; j < block_w; ++j) {
            if (i >= block_height || j >= block_width) {
              src_dst[j] = 0;
              blurred_dst[j] = 0;
            } else {
              src_dst[j] = frame_src_buf[j];
              blurred_dst[j] = frame_blurred_buf[j];
            }
          }
          frame_src_buf += source->y_stride;
          frame_blurred_buf += blurred.y_stride;
          src_dst += source_block.y_stride;
          blurred_dst += blurred_block.y_stride;
        }
      }

      best_unsharp_amounts[index] = find_best_frame_unsharp_amount(
          cpi, &source_block, &blurred_block, best_frame_unsharp_amount, 0.1, 3,
          1.5);
    }
  }

  // Apply each block's unsharp amount to 'source' in place.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int row_offset_y = row * block_h;
      const int col_offset_y = col * block_w;
      const int block_width = AOMMIN(source->y_width - col_offset_y, block_w);
      const int block_height = AOMMIN(source->y_height - row_offset_y, block_h);
      const int index = col + row * num_cols;

      if (bit_depth > 8) {
        uint16_t *src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
                            row_offset_y * source->y_stride + col_offset_y;
        uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
                                row_offset_y * blurred.y_stride + col_offset_y;
        highbd_unsharp_rect(src_buf, source->y_stride, blurred_buf,
                            blurred.y_stride, src_buf, source->y_stride,
                            block_width, block_height,
                            best_unsharp_amounts[index], bit_depth);
      } else {
        uint8_t *src_buf =
            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
        uint8_t *blurred_buf =
            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
        unsharp_rect(src_buf, source->y_stride, blurred_buf, blurred.y_stride,
                     src_buf, source->y_stride, block_width, block_height,
                     best_unsharp_amounts[index]);
      }
    }
  }

  aom_free_frame_buffer(&source_block);
  aom_free_frame_buffer(&blurred_block);
  aom_free_frame_buffer(&blurred);
  aom_free(best_unsharp_amounts);
  aom_clear_system_state();
}
416 
// State handed to update_frame() through its user_data pointer.
typedef struct FrameData {
  // Frames read by the callback: 'source' fills both outputs, 'blurred'
  // replaces the current block of the second output.
  const YV12_BUFFER_CONFIG *source, *blurred;
  // block_w / block_h:   block dimensions in luma samples.
  // num_rows / num_cols: number of blocks down and across the frame.
  // row / col:           block used by the next update_frame() call; the
  //                      callback advances them in raster order.
  // bit_depth:           sample bit depth (8 selects the 8-bit read path).
  int block_w, block_h, num_rows, num_cols, row, col, bit_depth;
} FrameData;
421 
422 // A callback function used to pass data to VMAF.
423 // Returns 0 after reading a frame.
424 // Returns 2 when there is no more frame to read.
update_frame(float * ref_data,float * main_data,float * temp_data,int stride,void * user_data)425 static int update_frame(float *ref_data, float *main_data, float *temp_data,
426                         int stride, void *user_data) {
427   FrameData *frames = (FrameData *)user_data;
428   const int width = frames->source->y_width;
429   const int height = frames->source->y_height;
430   const int row = frames->row;
431   const int col = frames->col;
432   const int num_rows = frames->num_rows;
433   const int num_cols = frames->num_cols;
434   const int block_w = frames->block_w;
435   const int block_h = frames->block_h;
436   const YV12_BUFFER_CONFIG *source = frames->source;
437   const YV12_BUFFER_CONFIG *blurred = frames->blurred;
438   const int bit_depth = frames->bit_depth;
439   const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
440   (void)temp_data;
441   stride /= (int)sizeof(*ref_data);
442 
443   for (int i = 0; i < height; ++i) {
444     float *ref, *main;
445     ref = ref_data + i * stride;
446     main = main_data + i * stride;
447     if (bit_depth == 8) {
448       uint8_t *src;
449       src = source->y_buffer + i * source->y_stride;
450       for (int j = 0; j < width; ++j) {
451         ref[j] = main[j] = (float)src[j];
452       }
453     } else {
454       uint16_t *src;
455       src = CONVERT_TO_SHORTPTR(source->y_buffer) + i * source->y_stride;
456       for (int j = 0; j < width; ++j) {
457         ref[j] = main[j] = scale_factor * (float)src[j];
458       }
459     }
460   }
461   if (row < num_rows && col < num_cols) {
462     // Set current block
463     const int row_offset = row * block_h;
464     const int col_offset = col * block_w;
465     const int block_width = AOMMIN(width - col_offset, block_w);
466     const int block_height = AOMMIN(height - row_offset, block_h);
467 
468     float *main_buf = main_data + col_offset + row_offset * stride;
469     if (bit_depth == 8) {
470       uint8_t *blurred_buf =
471           blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
472       for (int i = 0; i < block_height; ++i) {
473         for (int j = 0; j < block_width; ++j) {
474           main_buf[j] = (float)blurred_buf[j];
475         }
476         main_buf += stride;
477         blurred_buf += blurred->y_stride;
478       }
479     } else {
480       uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred->y_buffer) +
481                               row_offset * blurred->y_stride + col_offset;
482       for (int i = 0; i < block_height; ++i) {
483         for (int j = 0; j < block_width; ++j) {
484           main_buf[j] = scale_factor * (float)blurred_buf[j];
485         }
486         main_buf += stride;
487         blurred_buf += blurred->y_stride;
488       }
489     }
490 
491     frames->col++;
492     if (frames->col >= num_cols) {
493       frames->col = 0;
494       frames->row++;
495     }
496     return 0;
497   } else {
498     return 2;
499   }
500 }
501 
// Fills cpi->vmaf_rdmult_scaling_factors with one weight per block.
// The source frame is downscaled by a factor of 2 and blurred. VMAF is then
// evaluated once per BLOCK_32X32 block of the downscaled frame, with only
// that block replaced by its blurred version (see update_frame()). For each
// block the weight is derived from mse / (kBaselineVmaf - vmaf), where mse is
// the block's SSE against the blurred frame divided by the downscaled frame
// area, and is then mapped through 6 * (1 - exp(-0.05 * w)) + 0.8.
// If the score array cannot be allocated, every factor this function would
// have written is set to 1.0, which makes av1_set_vmaf_rdmult() leave rdmult
// unchanged.
void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
  AV1_COMMON *cm = &cpi->common;
  const int y_width = cpi->source->y_width;
  const int y_height = cpi->source->y_height;
  const int resized_block_size = BLOCK_32X32;
  const int resize_factor = 2;
  const int bit_depth = cpi->td.mb.e_mbd.bd;

  aom_clear_system_state();
  YV12_BUFFER_CONFIG resized_source;
  memset(&resized_source, 0, sizeof(resized_source));
  aom_alloc_frame_buffer(
      &resized_source, y_width / resize_factor, y_height / resize_factor, 1, 1,
      cm->seq_params.use_highbitdepth, cpi->oxcf.border_in_pixels,
      cm->features.byte_alignment);
  av1_resize_and_extend_frame(cpi->source, &resized_source, bit_depth,
                              av1_num_planes(cm));

  const int resized_y_width = resized_source.y_width;
  const int resized_y_height = resized_source.y_height;
  const int resized_block_w = mi_size_wide[resized_block_size] * 4;
  const int resized_block_h = mi_size_high[resized_block_size] * 4;
  const int num_cols =
      (resized_y_width + resized_block_w - 1) / resized_block_w;
  const int num_rows =
      (resized_y_height + resized_block_h - 1) / resized_block_h;
  const int num_blocks = num_rows * num_cols;

  YV12_BUFFER_CONFIG blurred;
  memset(&blurred, 0, sizeof(blurred));
  aom_alloc_frame_buffer(&blurred, resized_y_width, resized_y_height, 1, 1,
                         cm->seq_params.use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment);
  gaussian_blur(bit_depth, &resized_source, &blurred);

  double *scores = aom_malloc(sizeof(*scores) * num_blocks);
  if (scores == NULL) {
    // No room for the per-block scores: fall back to neutral factors.
    for (int i = 0; i < num_blocks; ++i) {
      cpi->vmaf_rdmult_scaling_factors[i] = 1.0;
    }
    aom_free_frame_buffer(&resized_source);
    aom_free_frame_buffer(&blurred);
    aom_clear_system_state();
    return;
  }
  memset(scores, 0, sizeof(*scores) * num_blocks);

  // update_frame() produces one frame pair per block, in raster order.
  FrameData frame_data;
  frame_data.source = &resized_source;
  frame_data.blurred = &blurred;
  frame_data.block_w = resized_block_w;
  frame_data.block_h = resized_block_h;
  frame_data.num_rows = num_rows;
  frame_data.num_cols = num_cols;
  frame_data.row = 0;
  frame_data.col = 0;
  frame_data.bit_depth = bit_depth;
  aom_calc_vmaf_multi_frame(&frame_data, cpi->oxcf.vmaf_model_path,
                            update_frame, resized_y_width, resized_y_height,
                            bit_depth, scores);

  // Values shared by every block.
  const double frame_area = (double)(resized_y_width * resized_y_height);
  const double eps = 0.01 / num_blocks;

  // Loop through each 'block_size' block.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int index = row * num_cols + col;
      const int row_offset_y = row * resized_block_h;
      const int col_offset_y = col * resized_block_w;

      uint8_t *const orig_buf = resized_source.y_buffer +
                                row_offset_y * resized_source.y_stride +
                                col_offset_y;
      uint8_t *const blurred_buf =
          blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;

      // VMAF drop caused by blurring this block alone.
      const double vmaf = scores[index];
      const double dvmaf = kBaselineVmaf - vmaf;
      unsigned int sse;
      cpi->fn_ptr[resized_block_size].vf(orig_buf, resized_source.y_stride,
                                         blurred_buf, blurred.y_stride, &sse);

      const double mse = (double)sse / frame_area;
      double weight;
      // Avoid dividing by a vanishing VMAF drop.
      if (dvmaf < eps || mse < eps) {
        weight = 1.0;
      } else {
        weight = mse / dvmaf;
      }

      // Normalize it with a data fitted model.
      weight = 6.0 * (1.0 - exp(-0.05 * weight)) + 0.8;
      cpi->vmaf_rdmult_scaling_factors[index] = weight;
    }
  }

  aom_free_frame_buffer(&resized_source);
  aom_free_frame_buffer(&blurred);
  aom_free(scores);
  aom_clear_system_state();
}
593 
// Scales '*rdmult' for the block of size 'bsize' at (mi_row, mi_col) by the
// geometric mean of the cpi->vmaf_rdmult_scaling_factors entries it covers,
// then passes the result to set_error_per_bit().
// The factors are indexed on a BLOCK_64X64 grid laid over the frame; the
// block covers num_brows x num_bcols grid cells starting at the cell that
// contains (mi_row, mi_col), clipped to the grid.
// The result is rounded to nearest and clamped to be non-negative. If no
// grid cell is covered, '*rdmult' is left unscaled.
void av1_set_vmaf_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
                         const BLOCK_SIZE bsize, const int mi_row,
                         const int mi_col, int *const rdmult) {
  const AV1_COMMON *const cm = &cpi->common;

  const int bsize_base = BLOCK_64X64;
  const int num_mi_w = mi_size_wide[bsize_base];
  const int num_mi_h = mi_size_high[bsize_base];
  const int num_cols = (cm->mi_params.mi_cols + num_mi_w - 1) / num_mi_w;
  const int num_rows = (cm->mi_params.mi_rows + num_mi_h - 1) / num_mi_h;
  const int num_bcols = (mi_size_wide[bsize] + num_mi_w - 1) / num_mi_w;
  const int num_brows = (mi_size_high[bsize] + num_mi_h - 1) / num_mi_h;
  // Rows are measured in units of the base block height, columns in units of
  // its width.
  const int first_row = mi_row / num_mi_h;
  const int first_col = mi_col / num_mi_w;
  double num_of_mi = 0.0;
  double log_sum_of_scale = 0.0;

  aom_clear_system_state();
  for (int row = first_row; row < num_rows && row < first_row + num_brows;
       ++row) {
    for (int col = first_col; col < num_cols && col < first_col + num_bcols;
         ++col) {
      const int index = row * num_cols + col;
      log_sum_of_scale += log(cpi->vmaf_rdmult_scaling_factors[index]);
      num_of_mi += 1.0;
    }
  }

  // With nothing accumulated the mean would be 0/0; skip the scaling rather
  // than convert a NaN to int.
  if (num_of_mi > 0.0) {
    const double geom_mean_of_scale = exp(log_sum_of_scale / num_of_mi);
    *rdmult = (int)((double)(*rdmult) * geom_mean_of_scale + 0.5);
  }
  *rdmult = AOMMAX(*rdmult, 0);
  set_error_per_bit(x, *rdmult);
  aom_clear_system_state();
}
627 
628 // TODO(sdeng): replace them with the SIMD versions.
highbd_image_sad_c(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int w,int h)629 static AOM_INLINE double highbd_image_sad_c(const uint16_t *src, int src_stride,
630                                             const uint16_t *ref, int ref_stride,
631                                             int w, int h) {
632   double accum = 0.0;
633   int i, j;
634 
635   for (i = 0; i < h; ++i) {
636     for (j = 0; j < w; ++j) {
637       double img1px = src[i * src_stride + j];
638       double img2px = ref[i * ref_stride + j];
639 
640       accum += fabs(img1px - img2px);
641     }
642   }
643 
644   return accum / (double)(h * w);
645 }
646 
image_sad_c(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int w,int h)647 static AOM_INLINE double image_sad_c(const uint8_t *src, int src_stride,
648                                      const uint8_t *ref, int ref_stride, int w,
649                                      int h) {
650   double accum = 0.0;
651   int i, j;
652 
653   for (i = 0; i < h; ++i) {
654     for (j = 0; j < w; ++j) {
655       double img1px = src[i * src_stride + j];
656       double img2px = ref[i * ref_stride + j];
657 
658       accum += fabs(img1px - img2px);
659     }
660   }
661 
662   return accum / (double)(h * w);
663 }
664 
calc_vmaf_motion_score(const AV1_COMP * const cpi,const AV1_COMMON * const cm,const YV12_BUFFER_CONFIG * const cur,const YV12_BUFFER_CONFIG * const last,const YV12_BUFFER_CONFIG * const next)665 static AOM_INLINE double calc_vmaf_motion_score(
666     const AV1_COMP *const cpi, const AV1_COMMON *const cm,
667     const YV12_BUFFER_CONFIG *const cur, const YV12_BUFFER_CONFIG *const last,
668     const YV12_BUFFER_CONFIG *const next) {
669   const int y_width = cur->y_width;
670   const int y_height = cur->y_height;
671   YV12_BUFFER_CONFIG blurred_cur, blurred_last, blurred_next;
672   const int bit_depth = cpi->td.mb.e_mbd.bd;
673 
674   memset(&blurred_cur, 0, sizeof(blurred_cur));
675   memset(&blurred_last, 0, sizeof(blurred_last));
676   memset(&blurred_next, 0, sizeof(blurred_next));
677 
678   aom_alloc_frame_buffer(
679       &blurred_cur, y_width, y_height, 1, 1, cm->seq_params.use_highbitdepth,
680       cpi->oxcf.border_in_pixels, cm->features.byte_alignment);
681   aom_alloc_frame_buffer(
682       &blurred_last, y_width, y_height, 1, 1, cm->seq_params.use_highbitdepth,
683       cpi->oxcf.border_in_pixels, cm->features.byte_alignment);
684   aom_alloc_frame_buffer(
685       &blurred_next, y_width, y_height, 1, 1, cm->seq_params.use_highbitdepth,
686       cpi->oxcf.border_in_pixels, cm->features.byte_alignment);
687 
688   gaussian_blur(bit_depth, cur, &blurred_cur);
689   gaussian_blur(bit_depth, last, &blurred_last);
690   if (next) gaussian_blur(bit_depth, next, &blurred_next);
691 
692   double motion1, motion2 = 65536.0;
693   if (bit_depth > 8) {
694     const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
695     motion1 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
696                                  blurred_cur.y_stride,
697                                  CONVERT_TO_SHORTPTR(blurred_last.y_buffer),
698                                  blurred_last.y_stride, y_width, y_height) *
699               scale_factor;
700     if (next) {
701       motion2 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
702                                    blurred_cur.y_stride,
703                                    CONVERT_TO_SHORTPTR(blurred_next.y_buffer),
704                                    blurred_next.y_stride, y_width, y_height) *
705                 scale_factor;
706     }
707   } else {
708     motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
709                           blurred_last.y_buffer, blurred_last.y_stride, y_width,
710                           y_height);
711     if (next) {
712       motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
713                             blurred_next.y_buffer, blurred_next.y_stride,
714                             y_width, y_height);
715     }
716   }
717 
718   aom_free_frame_buffer(&blurred_cur);
719   aom_free_frame_buffer(&blurred_last);
720   aom_free_frame_buffer(&blurred_next);
721 
722   return AOMMIN(motion1, motion2);
723 }
724 
725 // Calculates the new qindex from the VMAF motion score. This is based on the
726 // observation: when the motion score becomes higher, the VMAF score of the
727 // same source and distorted frames would become higher.
av1_get_vmaf_base_qindex(const AV1_COMP * const cpi,int current_qindex)728 int av1_get_vmaf_base_qindex(const AV1_COMP *const cpi, int current_qindex) {
729   const AV1_COMMON *const cm = &cpi->common;
730   if (cm->current_frame.frame_number == 0 || cpi->oxcf.pass == 1) {
731     return current_qindex;
732   }
733   const int bit_depth = cpi->td.mb.e_mbd.bd;
734   const double approx_sse =
735       cpi->last_frame_ysse /
736       (double)((1 << (bit_depth - 8)) * (1 << (bit_depth - 8)));
737   const double approx_dvmaf = kBaselineVmaf - cpi->last_frame_vmaf;
738   const double sse_threshold =
739       0.01 * cpi->source->y_width * cpi->source->y_height;
740   const double vmaf_threshold = 0.01;
741   if (approx_sse < sse_threshold || approx_dvmaf < vmaf_threshold) {
742     return current_qindex;
743   }
744   aom_clear_system_state();
745   const GF_GROUP *gf_group = &cpi->gf_group;
746   YV12_BUFFER_CONFIG *cur_buf = cpi->source;
747   int src_index = 0;
748   if (cm->show_frame == 0) {
749     src_index = gf_group->arf_src_offset[gf_group->index];
750     struct lookahead_entry *cur_entry =
751         av1_lookahead_peek(cpi->lookahead, src_index, cpi->compressor_stage);
752     cur_buf = &cur_entry->img;
753   }
754   assert(cur_buf);
755 
756   const struct lookahead_entry *last_entry =
757       av1_lookahead_peek(cpi->lookahead, src_index - 1, cpi->compressor_stage);
758   const struct lookahead_entry *next_entry =
759       av1_lookahead_peek(cpi->lookahead, src_index + 1, cpi->compressor_stage);
760   const YV12_BUFFER_CONFIG *next_buf = &next_entry->img;
761   const YV12_BUFFER_CONFIG *last_buf =
762       cm->show_frame ? cpi->last_source : &last_entry->img;
763 
764   assert(last_buf);
765 
766   const double motion =
767       calc_vmaf_motion_score(cpi, cm, cur_buf, last_buf, next_buf);
768 
769   // Get dVMAF through a data fitted model.
770   const double dvmaf = 26.11 * (1.0 - exp(-0.06 * motion));
771   const double dsse = dvmaf * approx_sse / approx_dvmaf;
772 
773   const double beta = approx_sse / (dsse + approx_sse);
774   const int offset = av1_get_deltaq_offset(cpi, current_qindex, beta);
775   int qindex = current_qindex + offset;
776 
777   qindex = AOMMIN(qindex, MAXQ);
778   qindex = AOMMAX(qindex, MINQ);
779 
780   aom_clear_system_state();
781   return qindex;
782 }
783 
// Measures `recon` against `source` and stores the results in `cpi`: the VMAF
// score (computed with the model at cpi->oxcf.vmaf_model_path) goes to
// cpi->last_frame_vmaf and the Y-plane SSE goes to cpi->last_frame_ysse.
void av1_update_vmaf_curve(AV1_COMP *cpi, YV12_BUFFER_CONFIG *source,
                           YV12_BUFFER_CONFIG *recon) {
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, recon, bit_depth,
                &cpi->last_frame_vmaf);
  // Choose the SSE routine by how the samples are stored rather than by the
  // bit depth: 8-bit content kept in 16-bit buffers must still be read with
  // the high-bitdepth routine.
  // NOTE(review): assumes `source` and `recon` follow
  // seq_params.use_highbitdepth -- confirm against the callers.
  if (cpi->common.seq_params.use_highbitdepth) {
    cpi->last_frame_ysse = (double)aom_highbd_get_y_sse(source, recon);
  } else {
    cpi->last_frame_ysse = (double)aom_get_y_sse(source, recon);
  }
}
795