/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>
#include <assert.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/x86/convolve_avx2.h"
#include "aom_dsp/x86/convolve_common_intrin.h"
#include "aom_dsp/x86/convolve_sse4_1.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "av1/common/convolve.h"

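// Compound "copy" path (no subpel filtering): the source is shifted left by
// the compound rounding bits and offset into the compound value range. On the
// first pass the result is stored into the CONV_BUF_TYPE buffer
// conv_params->dst; on the second pass (do_average) it is blended with that
// buffer, optionally with distance weights, rounded, clipped to the bit depth
// and written to dst0.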
void av1_highbd_dist_wtd_convolve_2d_copy_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_q4,
    const int subpel_y_q4, ConvolveParams *conv_params, int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  (void)filter_params_x;
  (void)filter_params_y;
  (void)subpel_x_q4;
  (void)subpel_y_q4;

  const int bits =
      FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
  const __m128i left_shift = _mm_cvtsi32_si128(bits);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();
  int i, j;

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const __m256i offset_const_16b = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  assert(bits <= 4);

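  // Wide blocks: widths that are a multiple of 16 are processed 16 pixels at
  // a time, one row per iteration.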
  if (!(w % 16)) {
    for (i = 0; i < h; i += 1) {
      for (j = 0; j < w; j += 16) {
        const __m256i src_16bit =
            _mm256_loadu_si256((__m256i *)(&src[i * src_stride + j]));

        const __m256i res = _mm256_sll_epi16(src_16bit, left_shift);

        if (do_average) {
          const __m256i data_0 =
              _mm256_loadu_si256((__m256i *)(&dst[i * dst_stride + j]));

          const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_0, zero);
          const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_0, zero);

          const __m256i res_32b_lo = _mm256_unpacklo_epi16(res, zero);
          const __m256i res_unsigned_lo =
              _mm256_add_epi32(res_32b_lo, offset_const);

          const __m256i comp_avg_res_lo =
              highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i res_32b_hi = _mm256_unpackhi_epi16(res, zero);
          const __m256i res_unsigned_hi =
              _mm256_add_epi32(res_32b_hi, offset_const);

          const __m256i comp_avg_res_hi =
              highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i round_result_lo = highbd_convolve_rounding(
              &comp_avg_res_lo, &offset_const, &rounding_const, rounding_shift);
          const __m256i round_result_hi = highbd_convolve_rounding(
              &comp_avg_res_hi, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result_lo, round_result_hi);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          _mm256_store_si256((__m256i *)(&dst0[i * dst_stride0 + j]), res_clip);
        } else {
          const __m256i res_unsigned_16b =
              _mm256_adds_epu16(res, offset_const_16b);

          _mm256_store_si256((__m256i *)(&dst[i * dst_stride + j]),
                             res_unsigned_16b);
        }
      }
    }
  } else if (!(w % 4)) {
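    // Narrow blocks (width 4 or 8): two rows are packed into one 256-bit
    // register and processed per iteration.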
    for (i = 0; i < h; i += 2) {
      for (j = 0; j < w; j += 8) {
        const __m128i src_row_0 =
            _mm_loadu_si128((__m128i *)(&src[i * src_stride + j]));
        const __m128i src_row_1 =
            _mm_loadu_si128((__m128i *)(&src[i * src_stride + j + src_stride]));
        // since not all compilers yet support _mm256_set_m128i()
        const __m256i src_10 = _mm256_insertf128_si256(
            _mm256_castsi128_si256(src_row_0), src_row_1, 1);

        const __m256i res = _mm256_sll_epi16(src_10, left_shift);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i res_32b = _mm256_unpacklo_epi16(res, zero);
            const __m256i res_unsigned_lo =
                _mm256_add_epi32(res_32b, offset_const);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            const __m256i res_unsigned_16b =
                _mm256_adds_epu16(res, offset_const_16b);

            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i res_32b_lo = _mm256_unpacklo_epi16(res, zero);
            const __m256i res_unsigned_lo =
                _mm256_add_epi32(res_32b_lo, offset_const);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i res_32b_hi = _mm256_unpackhi_epi16(res, zero);
            const __m256i res_unsigned_hi =
                _mm256_add_epi32(res_32b_hi, offset_const);

            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            const __m256i res_unsigned_16b =
                _mm256_adds_epu16(res, offset_const_16b);
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
      }
    }
  }
}

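// Compound 2D (horizontal then vertical) convolution. The horizontal pass
// filters an 8-pixel-wide column of rows into the 16-bit intermediate buffer
// im_block; the vertical pass filters im_block and either stores the result
// to the CONV_BUF_TYPE buffer or averages it with that buffer and writes
// pixels to dst0.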
void av1_highbd_dist_wtd_convolve_2d_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_q4,
    const int subpel_y_q4, ConvolveParams *conv_params, int bd) {
  DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int im_h = h + filter_params_y->taps - 1;
  int im_stride = 8;
  int i, j;
  const int fo_vert = filter_params_y->taps / 2 - 1;
  const int fo_horiz = filter_params_x->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

  // Check that, even with 12-bit input, the intermediate values will fit
  // into an unsigned 16-bit intermediate array.
  assert(bd + FILTER_BITS + 2 - conv_params->round_0 <= 16);

  __m256i s[8], coeffs_y[4], coeffs_x[4];
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;

  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();

  const __m256i round_const_x = _mm256_set1_epi32(
      ((1 << conv_params->round_0) >> 1) + (1 << (bd + FILTER_BITS - 1)));
  const __m128i round_shift_x = _mm_cvtsi32_si128(conv_params->round_0);

  const __m256i round_const_y = _mm256_set1_epi32(
      ((1 << conv_params->round_1) >> 1) -
      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
  const __m128i round_shift_y = _mm_cvtsi32_si128(conv_params->round_1);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);

  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  prepare_coeffs(filter_params_x, subpel_x_q4, coeffs_x);
  prepare_coeffs(filter_params_y, subpel_y_q4, coeffs_y);

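  // Process the block in 8-pixel-wide columns.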
  for (j = 0; j < w; j += 8) {
    /* Horizontal filter */
    {
      for (i = 0; i < im_h; i += 2) {
        const __m256i row0 =
            _mm256_loadu_si256((__m256i *)&src_ptr[i * src_stride + j]);
        __m256i row1 = _mm256_set1_epi16(0);
        if (i + 1 < im_h)
          row1 =
              _mm256_loadu_si256((__m256i *)&src_ptr[(i + 1) * src_stride + j]);

        const __m256i r0 = _mm256_permute2x128_si256(row0, row1, 0x20);
        const __m256i r1 = _mm256_permute2x128_si256(row0, row1, 0x31);

        // even pixels
        s[0] = _mm256_alignr_epi8(r1, r0, 0);
        s[1] = _mm256_alignr_epi8(r1, r0, 4);
        s[2] = _mm256_alignr_epi8(r1, r0, 8);
        s[3] = _mm256_alignr_epi8(r1, r0, 12);

        __m256i res_even = convolve(s, coeffs_x);
        res_even = _mm256_sra_epi32(_mm256_add_epi32(res_even, round_const_x),
                                    round_shift_x);

        // odd pixels
        s[0] = _mm256_alignr_epi8(r1, r0, 2);
        s[1] = _mm256_alignr_epi8(r1, r0, 6);
        s[2] = _mm256_alignr_epi8(r1, r0, 10);
        s[3] = _mm256_alignr_epi8(r1, r0, 14);

        __m256i res_odd = convolve(s, coeffs_x);
        res_odd = _mm256_sra_epi32(_mm256_add_epi32(res_odd, round_const_x),
                                   round_shift_x);

        __m256i res_even1 = _mm256_packs_epi32(res_even, res_even);
        __m256i res_odd1 = _mm256_packs_epi32(res_odd, res_odd);
        __m256i res = _mm256_unpacklo_epi16(res_even1, res_odd1);

        _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
      }
    }

    /* Vertical filter */
    {
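      // Prime the 8-tap sliding window with the first six intermediate rows;
      // the remaining two rows are loaded inside the loop below.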
      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
      __m256i s4 = _mm256_loadu_si256((__m256i *)(im_block + 4 * im_stride));
      __m256i s5 = _mm256_loadu_si256((__m256i *)(im_block + 5 * im_stride));

      s[0] = _mm256_unpacklo_epi16(s0, s1);
      s[1] = _mm256_unpacklo_epi16(s2, s3);
      s[2] = _mm256_unpacklo_epi16(s4, s5);

      s[4] = _mm256_unpackhi_epi16(s0, s1);
      s[5] = _mm256_unpackhi_epi16(s2, s3);
      s[6] = _mm256_unpackhi_epi16(s4, s5);

      for (i = 0; i < h; i += 2) {
        const int16_t *data = &im_block[i * im_stride];

        const __m256i s6 =
            _mm256_loadu_si256((__m256i *)(data + 6 * im_stride));
        const __m256i s7 =
            _mm256_loadu_si256((__m256i *)(data + 7 * im_stride));

        s[3] = _mm256_unpacklo_epi16(s6, s7);
        s[7] = _mm256_unpackhi_epi16(s6, s7);

        const __m256i res_a = convolve(s, coeffs_y);

        const __m256i res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a, round_const_y), round_shift_y);

        const __m256i res_unsigned_lo =
            _mm256_add_epi32(res_a_round, offset_const);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          const __m256i res_b = convolve(s + 4, coeffs_y);
          const __m256i res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b, round_const_y), round_shift_y);

          __m256i res_unsigned_hi = _mm256_add_epi32(res_b_round, offset_const);

          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);
            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }

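        // Slide the vertical-filter window down by two rows.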
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}

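// Compound horizontal-only convolution: each pair of rows is filtered with
// the subpel x filter, offset into the compound value range, and either
// stored to the CONV_BUF_TYPE buffer or averaged with it and written to dst0.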
void av1_highbd_dist_wtd_convolve_x_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_q4,
    const int subpel_y_q4, ConvolveParams *conv_params, int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int fo_horiz = filter_params_x->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_horiz;
  const int bits = FILTER_BITS - conv_params->round_1;
  (void)filter_params_y;
  (void)subpel_y_q4;

  int i, j;
  __m256i s[4], coeffs_x[4];

  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();

  const __m256i round_const_x =
      _mm256_set1_epi32(((1 << conv_params->round_0) >> 1));
  const __m128i round_shift_x = _mm_cvtsi32_si128(conv_params->round_0);
  const __m128i round_shift_bits = _mm_cvtsi32_si128(bits);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  assert(bits >= 0);
  prepare_coeffs(filter_params_x, subpel_x_q4, coeffs_x);

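  // Process the block in 8-pixel-wide columns, two rows per iteration.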
  for (j = 0; j < w; j += 8) {
    /* Horizontal filter */
    for (i = 0; i < h; i += 2) {
      const __m256i row0 =
          _mm256_loadu_si256((__m256i *)&src_ptr[i * src_stride + j]);
      __m256i row1 =
          _mm256_loadu_si256((__m256i *)&src_ptr[(i + 1) * src_stride + j]);

      const __m256i r0 = _mm256_permute2x128_si256(row0, row1, 0x20);
      const __m256i r1 = _mm256_permute2x128_si256(row0, row1, 0x31);

      // even pixels
      s[0] = _mm256_alignr_epi8(r1, r0, 0);
      s[1] = _mm256_alignr_epi8(r1, r0, 4);
      s[2] = _mm256_alignr_epi8(r1, r0, 8);
      s[3] = _mm256_alignr_epi8(r1, r0, 12);

      __m256i res_even = convolve(s, coeffs_x);
      res_even = _mm256_sra_epi32(_mm256_add_epi32(res_even, round_const_x),
                                  round_shift_x);

      // odd pixels
      s[0] = _mm256_alignr_epi8(r1, r0, 2);
      s[1] = _mm256_alignr_epi8(r1, r0, 6);
      s[2] = _mm256_alignr_epi8(r1, r0, 10);
      s[3] = _mm256_alignr_epi8(r1, r0, 14);

      __m256i res_odd = convolve(s, coeffs_x);
      res_odd = _mm256_sra_epi32(_mm256_add_epi32(res_odd, round_const_x),
                                 round_shift_x);

      res_even = _mm256_sll_epi32(res_even, round_shift_bits);
      res_odd = _mm256_sll_epi32(res_odd, round_shift_bits);

      __m256i res1 = _mm256_unpacklo_epi32(res_even, res_odd);

      __m256i res_unsigned_lo = _mm256_add_epi32(res1, offset_const);

      if (w - j < 8) {
        if (do_average) {
          const __m256i data_0 = _mm256_castsi128_si256(
              _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
          const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
              (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
          const __m256i data_01 =
              _mm256_permute2x128_si256(data_0, data_1, 0x20);

          const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

          const __m256i comp_avg_res = highbd_comp_avg(
              &data_ref_0, &res_unsigned_lo, &wt0, &wt1, use_dist_wtd_comp_avg);

          const __m256i round_result = highbd_convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result, round_result);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          const __m128i res_0 = _mm256_castsi256_si128(res_clip);
          const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

          _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
          _mm_storel_epi64(
              (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
        } else {
          __m256i res_16b =
              _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
          const __m128i res_0 = _mm256_castsi256_si128(res_16b);
          const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

          _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
          _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                           res_1);
        }
      } else {
        __m256i res2 = _mm256_unpackhi_epi32(res_even, res_odd);
        __m256i res_unsigned_hi = _mm256_add_epi32(res2, offset_const);

        if (do_average) {
          const __m256i data_0 = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
          const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
              (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
          const __m256i data_01 =
              _mm256_permute2x128_si256(data_0, data_1, 0x20);

          const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
          const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

          const __m256i comp_avg_res_lo =
              highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                              use_dist_wtd_comp_avg);
          const __m256i comp_avg_res_hi =
              highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i round_result_lo = highbd_convolve_rounding(
              &comp_avg_res_lo, &offset_const, &rounding_const, rounding_shift);
          const __m256i round_result_hi = highbd_convolve_rounding(
              &comp_avg_res_hi, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result_lo, round_result_hi);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          const __m128i res_0 = _mm256_castsi256_si128(res_clip);
          const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

          _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
          _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]),
                          res_1);
        } else {
          __m256i res_16b =
              _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
          const __m128i res_0 = _mm256_castsi256_si128(res_16b);
          const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}

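// Compound vertical-only convolution: source rows are packed pairwise into
// 256-bit registers, filtered with the subpel y filter, offset into the
// compound value range, and either stored to the CONV_BUF_TYPE buffer or
// averaged with it and written to dst0.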
void av1_highbd_dist_wtd_convolve_y_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_q4,
    const int subpel_y_q4, ConvolveParams *conv_params, int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int fo_vert = filter_params_y->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_vert * src_stride;
  const int bits = FILTER_BITS - conv_params->round_0;
  (void)filter_params_x;
  (void)subpel_x_q4;

  assert(bits >= 0);
  int i, j;
  __m256i s[8], coeffs_y[4];
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;

  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i round_const_y =
      _mm256_set1_epi32(((1 << conv_params->round_1) >> 1));
  const __m128i round_shift_y = _mm_cvtsi32_si128(conv_params->round_1);
  const __m128i round_shift_bits = _mm_cvtsi32_si128(bits);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));
  const __m256i zero = _mm256_setzero_si256();

  prepare_coeffs(filter_params_y, subpel_y_q4, coeffs_y);

  for (j = 0; j < w; j += 8) {
    const uint16_t *data = &src_ptr[j];
    /* Vertical filter */
    {
      __m256i src6;
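      // Load the first seven source rows and pack consecutive row pairs into
      // 256-bit registers to prime the vertical filter window.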
      __m256i s01 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 0 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 1 * src_stride))),
          0x20);
      __m256i s12 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 1 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 2 * src_stride))),
          0x20);
      __m256i s23 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 2 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 3 * src_stride))),
          0x20);
      __m256i s34 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 3 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 4 * src_stride))),
          0x20);
      __m256i s45 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 4 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 5 * src_stride))),
          0x20);
      src6 = _mm256_castsi128_si256(
          _mm_loadu_si128((__m128i *)(data + 6 * src_stride)));
      __m256i s56 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 5 * src_stride))),
          src6, 0x20);

      s[0] = _mm256_unpacklo_epi16(s01, s12);
      s[1] = _mm256_unpacklo_epi16(s23, s34);
      s[2] = _mm256_unpacklo_epi16(s45, s56);

      s[4] = _mm256_unpackhi_epi16(s01, s12);
      s[5] = _mm256_unpackhi_epi16(s23, s34);
      s[6] = _mm256_unpackhi_epi16(s45, s56);

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[i * src_stride + j];

        const __m256i s67 = _mm256_permute2x128_si256(
            src6,
            _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(data + 7 * src_stride))),
            0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + 8 * src_stride)));

        const __m256i s78 = _mm256_permute2x128_si256(
            _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(data + 7 * src_stride))),
            src6, 0x20);

        s[3] = _mm256_unpacklo_epi16(s67, s78);
        s[7] = _mm256_unpackhi_epi16(s67, s78);

        const __m256i res_a = convolve(s, coeffs_y);

        __m256i res_a_round = _mm256_sll_epi32(res_a, round_shift_bits);
        res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a_round, round_const_y), round_shift_y);

        __m256i res_unsigned_lo = _mm256_add_epi32(res_a_round, offset_const);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          const __m256i res_b = convolve(s + 4, coeffs_y);
          __m256i res_b_round = _mm256_sll_epi32(res_b, round_shift_bits);
          res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b_round, round_const_y), round_shift_y);

          __m256i res_unsigned_hi = _mm256_add_epi32(res_b_round, offset_const);

          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);
            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
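        // Slide the vertical-filter window down by two rows.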
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}