1 /*
2  * Copyright (c) 2017, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
#include <assert.h>
#include <smmintrin.h>
#include <string.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "av1/common/convolve.h"
20 
21 // A specialised version of hfilter, the horizontal filter for
22 // av1_convolve_2d_scale_sse4_1. This version only supports 8 tap filters.
hfilter8(const uint8_t * src,int src_stride,int16_t * dst,int w,int h,int subpel_x_qn,int x_step_qn,const InterpFilterParams * filter_params,unsigned round)23 static void hfilter8(const uint8_t *src, int src_stride, int16_t *dst, int w,
24                      int h, int subpel_x_qn, int x_step_qn,
25                      const InterpFilterParams *filter_params, unsigned round) {
26   const int bd = 8;
27   const int ntaps = 8;
28 
29   src -= ntaps / 2 - 1;
30 
31   int32_t round_add32 = (1 << round) / 2 + (1 << (bd + FILTER_BITS - 1));
32   const __m128i round_add = _mm_set1_epi32(round_add32);
33   const __m128i round_shift = _mm_cvtsi32_si128(round);
34 
35   int x_qn = subpel_x_qn;
36   for (int x = 0; x < w; ++x, x_qn += x_step_qn) {
37     const uint8_t *const src_col = src + (x_qn >> SCALE_SUBPEL_BITS);
38     const int filter_idx = (x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
39     assert(filter_idx < SUBPEL_SHIFTS);
40     const int16_t *filter =
41         av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
42 
43     // Load the filter coefficients
44     const __m128i coefflo = _mm_loadu_si128((__m128i *)filter);
45     const __m128i zero = _mm_castps_si128(_mm_setzero_ps());
46 
47     int y;
48     for (y = 0; y <= h - 4; y += 4) {
49       const uint8_t *const src0 = src_col + y * src_stride;
50       const uint8_t *const src1 = src0 + 1 * src_stride;
51       const uint8_t *const src2 = src0 + 2 * src_stride;
52       const uint8_t *const src3 = src0 + 3 * src_stride;
53 
54       // Load up source data. This is 8-bit input data; each load is just
55       // loading the lower half of the register and gets 8 pixels
56       const __m128i data08 = _mm_loadl_epi64((__m128i *)src0);
57       const __m128i data18 = _mm_loadl_epi64((__m128i *)src1);
58       const __m128i data28 = _mm_loadl_epi64((__m128i *)src2);
59       const __m128i data38 = _mm_loadl_epi64((__m128i *)src3);
60 
61       // Now zero-extend up to 16-bit precision by interleaving with
62       // zeros. Drop the upper half of each register (which just had zeros)
63       const __m128i data0lo = _mm_unpacklo_epi8(data08, zero);
64       const __m128i data1lo = _mm_unpacklo_epi8(data18, zero);
65       const __m128i data2lo = _mm_unpacklo_epi8(data28, zero);
66       const __m128i data3lo = _mm_unpacklo_epi8(data38, zero);
67 
68       // Multiply by coefficients
69       const __m128i conv0lo = _mm_madd_epi16(data0lo, coefflo);
70       const __m128i conv1lo = _mm_madd_epi16(data1lo, coefflo);
71       const __m128i conv2lo = _mm_madd_epi16(data2lo, coefflo);
72       const __m128i conv3lo = _mm_madd_epi16(data3lo, coefflo);
73 
74       // Reduce horizontally and add
75       const __m128i conv01lo = _mm_hadd_epi32(conv0lo, conv1lo);
76       const __m128i conv23lo = _mm_hadd_epi32(conv2lo, conv3lo);
77       const __m128i conv = _mm_hadd_epi32(conv01lo, conv23lo);
78 
79       // Divide down by (1 << round), rounding to nearest.
80       __m128i shifted =
81           _mm_sra_epi32(_mm_add_epi32(conv, round_add), round_shift);
82 
83       shifted = _mm_packus_epi32(shifted, shifted);
84       // Write transposed to the output
85       _mm_storel_epi64((__m128i *)(dst + y + x * h), shifted);
86     }
87     for (; y < h; ++y) {
88       const uint8_t *const src_row = src_col + y * src_stride;
89 
90       int32_t sum = (1 << (bd + FILTER_BITS - 1));
91       for (int k = 0; k < ntaps; ++k) {
92         sum += filter[k] * src_row[k];
93       }
94 
95       dst[y + x * h] = ROUND_POWER_OF_TWO(sum, round);
96     }
97   }
98 }
99 
/* Loads eight int16 samples from |src| (no alignment requirement), multiplies
 * them lane-by-lane with the eight int16 taps in |coeff| and adds adjacent
 * products, returning four int32 partial sums: taps {0,1}, {2,3}, {4,5} and
 * {6,7}. */
static __m128i convolve_16_8(const int16_t *src, __m128i coeff) {
  return _mm_madd_epi16(_mm_loadu_si128((const __m128i *)src), coeff);
}
104 
105 // A specialised version of vfilter, the vertical filter for
106 // av1_convolve_2d_scale_sse4_1. This version only supports 8 tap filters.
vfilter8(const int16_t * src,int src_stride,uint8_t * dst,int dst_stride,int w,int h,int subpel_y_qn,int y_step_qn,const InterpFilterParams * filter_params,const ConvolveParams * conv_params,int bd)107 static void vfilter8(const int16_t *src, int src_stride, uint8_t *dst,
108                      int dst_stride, int w, int h, int subpel_y_qn,
109                      int y_step_qn, const InterpFilterParams *filter_params,
110                      const ConvolveParams *conv_params, int bd) {
111   const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
112   const int ntaps = 8;
113 
114   const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
115 
116   const int32_t sub32 = ((1 << (offset_bits - conv_params->round_1)) +
117                          (1 << (offset_bits - conv_params->round_1 - 1)));
118   const __m128i sub = _mm_set1_epi16(sub32);
119 
120   CONV_BUF_TYPE *dst16 = conv_params->dst;
121   const int dst16_stride = conv_params->dst_stride;
122   const int bits =
123       FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
124   const __m128i bits_shift = _mm_cvtsi32_si128(bits);
125   const __m128i bits_const = _mm_set1_epi16(((1 << bits) >> 1));
126   const __m128i round_shift_add =
127       _mm_set1_epi32(((1 << conv_params->round_1) >> 1));
128   const __m128i res_add_const = _mm_set1_epi32(1 << offset_bits);
129 
130   const int w0 = conv_params->fwd_offset;
131   const int w1 = conv_params->bck_offset;
132   const __m128i wt0 = _mm_set1_epi16(w0);
133   const __m128i wt1 = _mm_set1_epi16(w1);
134   const __m128i wt = _mm_unpacklo_epi16(wt0, wt1);
135 
136   int y_qn = subpel_y_qn;
137   for (int y = 0; y < h; ++y, y_qn += y_step_qn) {
138     const int16_t *src_y = src + (y_qn >> SCALE_SUBPEL_BITS);
139     const int filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
140     assert(filter_idx < SUBPEL_SHIFTS);
141     const int16_t *filter =
142         av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
143 
144     const __m128i coeff0716 = _mm_loadu_si128((__m128i *)filter);
145     int x;
146     for (x = 0; x <= w - 4; x += 4) {
147       const int16_t *const src0 = src_y + x * src_stride;
148       const int16_t *const src1 = src0 + 1 * src_stride;
149       const int16_t *const src2 = src0 + 2 * src_stride;
150       const int16_t *const src3 = src0 + 3 * src_stride;
151 
152       // Load the source data for the three rows, adding the three registers of
153       // convolved products to one as we go (conv0..conv3) to avoid the
154       // register pressure getting too high.
155       const __m128i conv0 = convolve_16_8(src0, coeff0716);
156       const __m128i conv1 = convolve_16_8(src1, coeff0716);
157       const __m128i conv2 = convolve_16_8(src2, coeff0716);
158       const __m128i conv3 = convolve_16_8(src3, coeff0716);
159 
160       // Now reduce horizontally to get one lane for each result
161       const __m128i conv01 = _mm_hadd_epi32(conv0, conv1);
162       const __m128i conv23 = _mm_hadd_epi32(conv2, conv3);
163       __m128i conv = _mm_hadd_epi32(conv01, conv23);
164 
165       conv = _mm_add_epi32(conv, res_add_const);
166       // Divide down by (1 << round_1), rounding to nearest and subtract sub32.
167       __m128i shifted =
168           _mm_sra_epi32(_mm_add_epi32(conv, round_shift_add), round_shift);
169 
170       uint8_t *dst_x = dst + y * dst_stride + x;
171       CONV_BUF_TYPE *dst_16_x = dst16 + y * dst16_stride + x;
172       __m128i result;
173       __m128i shifted_16 = _mm_packus_epi32(shifted, shifted);
174 
175       if (conv_params->is_compound) {
176         if (conv_params->do_average) {
177           const __m128i p_16 = _mm_loadl_epi64((__m128i *)dst_16_x);
178           if (conv_params->use_dist_wtd_comp_avg) {
179             const __m128i p_16_lo = _mm_unpacklo_epi16(p_16, shifted_16);
180             const __m128i wt_res_lo = _mm_madd_epi16(p_16_lo, wt);
181             const __m128i shifted_32 =
182                 _mm_srai_epi32(wt_res_lo, DIST_PRECISION_BITS);
183             shifted_16 = _mm_packus_epi32(shifted_32, shifted_32);
184           } else {
185             shifted_16 = _mm_srai_epi16(_mm_add_epi16(p_16, shifted_16), 1);
186           }
187           const __m128i subbed = _mm_sub_epi16(shifted_16, sub);
188           result = _mm_sra_epi16(_mm_add_epi16(subbed, bits_const), bits_shift);
189           const __m128i result_8 = _mm_packus_epi16(result, result);
190           *(uint32_t *)dst_x = _mm_cvtsi128_si32(result_8);
191         } else {
192           _mm_storel_epi64((__m128i *)dst_16_x, shifted_16);
193         }
194       } else {
195         const __m128i subbed = _mm_sub_epi16(shifted_16, sub);
196         result = _mm_sra_epi16(_mm_add_epi16(subbed, bits_const), bits_shift);
197         const __m128i result_8 = _mm_packus_epi16(result, result);
198         *(uint32_t *)dst_x = _mm_cvtsi128_si32(result_8);
199       }
200     }
201     for (; x < w; ++x) {
202       const int16_t *src_x = src_y + x * src_stride;
203       int32_t sum = 1 << offset_bits;
204       for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
205       CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
206 
207       if (conv_params->is_compound) {
208         if (conv_params->do_average) {
209           int32_t tmp = dst16[y * dst16_stride + x];
210           if (conv_params->use_dist_wtd_comp_avg) {
211             tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
212             tmp = tmp >> DIST_PRECISION_BITS;
213           } else {
214             tmp += res;
215             tmp = tmp >> 1;
216           }
217           /* Subtract round offset and convolve round */
218           tmp = tmp - sub32;
219           dst[y * dst_stride + x] = clip_pixel(ROUND_POWER_OF_TWO(tmp, bits));
220         } else {
221           dst16[y * dst16_stride + x] = res;
222         }
223       } else {
224         /* Subtract round offset and convolve round */
225         int32_t tmp = res - ((1 << (offset_bits - conv_params->round_1)) +
226                              (1 << (offset_bits - conv_params->round_1 - 1)));
227         dst[y * dst_stride + x] = clip_pixel(ROUND_POWER_OF_TWO(tmp, bits));
228       }
229     }
230   }
231 }
void av1_convolve_2d_scale_sse4_1(const uint8_t *src, int src_stride,
                                  uint8_t *dst8, int dst8_stride, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_x_qn, const int x_step_qn,
                                  const int subpel_y_qn, const int y_step_qn,
                                  ConvolveParams *conv_params) {
  // Both passes below are hard-wired for 8-tap kernels.
  assert((filter_params_x->taps == 8) && (filter_params_y->taps == 8));

  const int ytaps = filter_params_y->taps;
  // Intermediate rows the vertical pass reads: the integer source row of the
  // last output row plus the vertical kernel length.
  const int im_h =
      (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) + ytaps;
  // Start this many rows above the first output row so the vertical kernel
  // is centred.
  const int fo_vert = ytaps / 2 - 1;

  // Transposed intermediate: im_h entries per output column.
  // TODO(yaowu): remove unnecessary initializations
  int16_t im_block[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE] = { 0 };

  hfilter8(src - fo_vert * src_stride, src_stride, im_block, w, im_h,
           subpel_x_qn, x_step_qn, filter_params_x, conv_params->round_0);

  // The intermediate is transposed, so the vertical pass uses stride im_h.
  vfilter8(im_block, im_h, dst8, dst8_stride, w, h, subpel_y_qn, y_step_qn,
           filter_params_y, conv_params, 8);
}
258 
259 // A specialised version of hfilter, the horizontal filter for
260 // av1_highbd_convolve_2d_scale_sse4_1. This version only supports 8 tap
261 // filters.
highbd_hfilter8(const uint16_t * src,int src_stride,int16_t * dst,int w,int h,int subpel_x_qn,int x_step_qn,const InterpFilterParams * filter_params,unsigned round,int bd)262 static void highbd_hfilter8(const uint16_t *src, int src_stride, int16_t *dst,
263                             int w, int h, int subpel_x_qn, int x_step_qn,
264                             const InterpFilterParams *filter_params,
265                             unsigned round, int bd) {
266   const int ntaps = 8;
267 
268   src -= ntaps / 2 - 1;
269 
270   int32_t round_add32 = (1 << round) / 2 + (1 << (bd + FILTER_BITS - 1));
271   const __m128i round_add = _mm_set1_epi32(round_add32);
272   const __m128i round_shift = _mm_cvtsi32_si128(round);
273 
274   int x_qn = subpel_x_qn;
275   for (int x = 0; x < w; ++x, x_qn += x_step_qn) {
276     const uint16_t *const src_col = src + (x_qn >> SCALE_SUBPEL_BITS);
277     const int filter_idx = (x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
278     assert(filter_idx < SUBPEL_SHIFTS);
279     const int16_t *filter =
280         av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
281 
282     // Load the filter coefficients
283     const __m128i coefflo = _mm_loadu_si128((__m128i *)filter);
284 
285     int y;
286     for (y = 0; y <= h - 4; y += 4) {
287       const uint16_t *const src0 = src_col + y * src_stride;
288       const uint16_t *const src1 = src0 + 1 * src_stride;
289       const uint16_t *const src2 = src0 + 2 * src_stride;
290       const uint16_t *const src3 = src0 + 3 * src_stride;
291 
292       // Load up source data. This is 16-bit input data, so each load gets the 8
293       // pixels we need.
294       const __m128i data0lo = _mm_loadu_si128((__m128i *)src0);
295       const __m128i data1lo = _mm_loadu_si128((__m128i *)src1);
296       const __m128i data2lo = _mm_loadu_si128((__m128i *)src2);
297       const __m128i data3lo = _mm_loadu_si128((__m128i *)src3);
298 
299       // Multiply by coefficients
300       const __m128i conv0lo = _mm_madd_epi16(data0lo, coefflo);
301       const __m128i conv1lo = _mm_madd_epi16(data1lo, coefflo);
302       const __m128i conv2lo = _mm_madd_epi16(data2lo, coefflo);
303       const __m128i conv3lo = _mm_madd_epi16(data3lo, coefflo);
304 
305       // Reduce horizontally and add
306       const __m128i conv01lo = _mm_hadd_epi32(conv0lo, conv1lo);
307       const __m128i conv23lo = _mm_hadd_epi32(conv2lo, conv3lo);
308       const __m128i conv = _mm_hadd_epi32(conv01lo, conv23lo);
309 
310       // Divide down by (1 << round), rounding to nearest.
311       __m128i shifted =
312           _mm_sra_epi32(_mm_add_epi32(conv, round_add), round_shift);
313 
314       shifted = _mm_packus_epi32(shifted, shifted);
315       // Write transposed to the output
316       _mm_storel_epi64((__m128i *)(dst + y + x * h), shifted);
317     }
318     for (; y < h; ++y) {
319       const uint16_t *const src_row = src_col + y * src_stride;
320 
321       int32_t sum = (1 << (bd + FILTER_BITS - 1));
322       for (int k = 0; k < ntaps; ++k) {
323         sum += filter[k] * src_row[k];
324       }
325 
326       dst[y + x * h] = ROUND_POWER_OF_TWO(sum, round);
327     }
328   }
329 }
330 // A specialised version of vfilter, the vertical filter for
331 // av1_highbd_convolve_2d_scale_sse4_1. This version only supports 8 tap
332 // filters.
highbd_vfilter8(const int16_t * src,int src_stride,uint16_t * dst,int dst_stride,int w,int h,int subpel_y_qn,int y_step_qn,const InterpFilterParams * filter_params,const ConvolveParams * conv_params,int bd)333 static void highbd_vfilter8(const int16_t *src, int src_stride, uint16_t *dst,
334                             int dst_stride, int w, int h, int subpel_y_qn,
335                             int y_step_qn,
336                             const InterpFilterParams *filter_params,
337                             const ConvolveParams *conv_params, int bd) {
338   const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
339   const int ntaps = 8;
340 
341   const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
342 
343   const int32_t sub32 = ((1 << (offset_bits - conv_params->round_1)) +
344                          (1 << (offset_bits - conv_params->round_1 - 1)));
345   const __m128i sub = _mm_set1_epi32(sub32);
346 
347   CONV_BUF_TYPE *dst16 = conv_params->dst;
348   const int dst16_stride = conv_params->dst_stride;
349   const __m128i clip_pixel_ =
350       _mm_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));
351   const int bits =
352       FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
353   const __m128i bits_shift = _mm_cvtsi32_si128(bits);
354   const __m128i bits_const = _mm_set1_epi32(((1 << bits) >> 1));
355   const __m128i round_shift_add =
356       _mm_set1_epi32(((1 << conv_params->round_1) >> 1));
357   const __m128i res_add_const = _mm_set1_epi32(1 << offset_bits);
358   const int round_bits =
359       2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
360   __m128i round_bits_shift = _mm_cvtsi32_si128(round_bits);
361   __m128i round_bits_const = _mm_set1_epi32(((1 << round_bits) >> 1));
362 
363   const int w0 = conv_params->fwd_offset;
364   const int w1 = conv_params->bck_offset;
365   const __m128i wt0 = _mm_set1_epi32(w0);
366   const __m128i wt1 = _mm_set1_epi32(w1);
367 
368   int y_qn = subpel_y_qn;
369   for (int y = 0; y < h; ++y, y_qn += y_step_qn) {
370     const int16_t *src_y = src + (y_qn >> SCALE_SUBPEL_BITS);
371     const int filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
372     assert(filter_idx < SUBPEL_SHIFTS);
373     const int16_t *filter =
374         av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
375 
376     const __m128i coeff0716 = _mm_loadu_si128((__m128i *)filter);
377     int x;
378     for (x = 0; x <= w - 4; x += 4) {
379       const int16_t *const src0 = src_y + x * src_stride;
380       const int16_t *const src1 = src0 + 1 * src_stride;
381       const int16_t *const src2 = src0 + 2 * src_stride;
382       const int16_t *const src3 = src0 + 3 * src_stride;
383 
384       // Load the source data for the three rows, adding the three registers of
385       // convolved products to one as we go (conv0..conv3) to avoid the
386       // register pressure getting too high.
387       const __m128i conv0 = convolve_16_8(src0, coeff0716);
388       const __m128i conv1 = convolve_16_8(src1, coeff0716);
389       const __m128i conv2 = convolve_16_8(src2, coeff0716);
390       const __m128i conv3 = convolve_16_8(src3, coeff0716);
391 
392       // Now reduce horizontally to get one lane for each result
393       const __m128i conv01 = _mm_hadd_epi32(conv0, conv1);
394       const __m128i conv23 = _mm_hadd_epi32(conv2, conv3);
395       __m128i conv = _mm_hadd_epi32(conv01, conv23);
396       conv = _mm_add_epi32(conv, res_add_const);
397 
398       // Divide down by (1 << round_1), rounding to nearest and subtract sub32.
399       __m128i shifted =
400           _mm_sra_epi32(_mm_add_epi32(conv, round_shift_add), round_shift);
401 
402       uint16_t *dst_x = dst + y * dst_stride + x;
403       CONV_BUF_TYPE *dst_16_x = dst16 + y * dst16_stride + x;
404 
405       __m128i result;
406       if (conv_params->is_compound) {
407         if (conv_params->do_average) {
408           __m128i p_32 =
409               _mm_cvtepu16_epi32(_mm_loadl_epi64((__m128i *)dst_16_x));
410 
411           if (conv_params->use_dist_wtd_comp_avg) {
412             shifted = _mm_add_epi32(_mm_mullo_epi32(p_32, wt0),
413                                     _mm_mullo_epi32(shifted, wt1));
414             shifted = _mm_srai_epi32(shifted, DIST_PRECISION_BITS);
415           } else {
416             shifted = _mm_srai_epi32(_mm_add_epi32(p_32, shifted), 1);
417           }
418           __m128i res32 = _mm_sub_epi32(shifted, sub);
419           res32 = _mm_sra_epi32(_mm_add_epi32(res32, round_bits_const),
420                                 round_bits_shift);
421 
422           __m128i res16 = _mm_packus_epi32(res32, res32);
423           res16 = _mm_min_epi16(res16, clip_pixel_);
424           _mm_storel_epi64((__m128i *)dst_x, res16);
425         } else {
426           __m128i shifted_16 = _mm_packus_epi32(shifted, shifted);
427           _mm_storel_epi64((__m128i *)dst_16_x, shifted_16);
428         }
429       } else {
430         const __m128i subbed = _mm_sub_epi32(shifted, sub);
431         result = _mm_sra_epi16(_mm_add_epi32(subbed, bits_const), bits_shift);
432         result = _mm_packus_epi32(result, result);
433         result = _mm_min_epi16(result, clip_pixel_);
434         _mm_storel_epi64((__m128i *)dst_x, result);
435       }
436     }
437 
438     for (; x < w; ++x) {
439       const int16_t *src_x = src_y + x * src_stride;
440       int32_t sum = 1 << offset_bits;
441       for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
442       CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
443       if (conv_params->is_compound) {
444         if (conv_params->do_average) {
445           int32_t tmp = dst16[y * dst16_stride + x];
446           if (conv_params->use_dist_wtd_comp_avg) {
447             tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
448             tmp = tmp >> DIST_PRECISION_BITS;
449           } else {
450             tmp += res;
451             tmp = tmp >> 1;
452           }
453           /* Subtract round offset and convolve round */
454           tmp = tmp - ((1 << (offset_bits - conv_params->round_1)) +
455                        (1 << (offset_bits - conv_params->round_1 - 1)));
456           dst[y * dst_stride + x] =
457               clip_pixel_highbd(ROUND_POWER_OF_TWO(tmp, bits), bd);
458         } else {
459           dst16[y * dst16_stride + x] = res;
460         }
461       } else {
462         /* Subtract round offset and convolve round */
463         int32_t tmp = res - ((1 << (offset_bits - conv_params->round_1)) +
464                              (1 << (offset_bits - conv_params->round_1 - 1)));
465         dst[y * dst_stride + x] =
466             clip_pixel_highbd(ROUND_POWER_OF_TWO(tmp, bits), bd);
467       }
468     }
469   }
470 }
471 
void av1_highbd_convolve_2d_scale_sse4_1(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int x_step_qn, const int subpel_y_qn, const int y_step_qn,
    ConvolveParams *conv_params, int bd) {
  // Both passes below are hard-wired for 8-tap kernels.
  assert((filter_params_x->taps == 8) && (filter_params_y->taps == 8));

  const int ytaps = filter_params_y->taps;
  // Intermediate rows the vertical pass reads: the integer source row of the
  // last output row plus the vertical kernel length.
  const int im_h =
      (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) + ytaps;
  // Start this many rows above the first output row so the vertical kernel
  // is centred.
  const int fo_vert = ytaps / 2 - 1;

  // Transposed intermediate: im_h entries per output column, zero-filled.
  // TODO(yaowu): Move this out of stack
  DECLARE_ALIGNED(16, int16_t,
                  im_block[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  memset(im_block, 0, sizeof(im_block));

  highbd_hfilter8(src - fo_vert * src_stride, src_stride, im_block, w, im_h,
                  subpel_x_qn, x_step_qn, filter_params_x, conv_params->round_0,
                  bd);

  // The intermediate is transposed, so the vertical pass uses stride im_h.
  highbd_vfilter8(im_block, im_h, dst, dst_stride, w, h, subpel_y_qn,
                  y_step_qn, filter_params_y, conv_params, bd);
}
500