1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
#include <immintrin.h>
#include <string.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/x86/convolve_avx2.h"
#include "aom_dsp/x86/convolve_common_intrin.h"
#include "aom_dsp/x86/convolve_sse4_1.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "av1/common/convolve.h"
22 
unpack_weights_avx2(ConvolveParams * conv_params)23 static INLINE __m256i unpack_weights_avx2(ConvolveParams *conv_params) {
24   const int w0 = conv_params->fwd_offset;
25   const int w1 = conv_params->bck_offset;
26   const __m256i wt0 = _mm256_set1_epi16(w0);
27   const __m256i wt1 = _mm256_set1_epi16(w1);
28   const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1);
29   return wt;
30 }
31 
load_line2_avx2(const void * a,const void * b)32 static INLINE __m256i load_line2_avx2(const void *a, const void *b) {
33   return _mm256_permute2x128_si256(
34       _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)a)),
35       _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)b)), 0x20);
36 }
37 
// Horizontal-only compound convolution for 8-bit input (AVX2).
//
// Filters `src` with filter_params_x at sub-pixel phase subpel_x_q4, two rows
// and eight columns per inner iteration, and then either
//  - stores the offset 16-bit intermediate to conv_params->dst with stride
//    conv_params->dst_stride (do_average == 0), or
//  - combines it with the intermediate already in conv_params->dst using
//    comp_avg() (weighted by fwd_offset/bck_offset via `wt` when
//    use_dist_wtd_comp_avg is set), rounds it with convolve_rounding(), packs
//    it to 8 bits and writes the pixels to dst0 with stride dst_stride0
//    (do_average != 0).
//
// A 4-tap path is used when the low 32 bits of coeffs[0] | coeffs[3] are zero
// (presumably: the outermost filter taps are zero).
//
// filter_params_y and subpel_y_q4 are unused (presumably kept so the signature
// matches the other convolve kernels).
//
// NOTE(review): rows are handled in pairs and columns in groups of 8 (4-byte
// stores when w <= 4), so this assumes h is even and w is 4 or a multiple of
// 8 — confirm against callers. The do_average == 0 path always writes eight
// 16-bit values per row, even when w == 4.
void av1_dist_wtd_convolve_x_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_x_q4, const int subpel_y_q4,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_horiz_4tap = 0;
  // Left shift applied after the horizontal rounding (asserted >= 0 below).
  const int bits = FILTER_BITS - conv_params->round_1;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  // Bias added to every intermediate value; the result is named
  // "res_unsigned" below, presumably because the bias keeps it non-negative.
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  // Final rounding used by convolve_rounding() when averaging.
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(bits >= 0);
  assert(conv_params->round_0 > 0);

  // Horizontal rounding uses round_0 - 1 rather than round_0; presumably the
  // coefficients from prepare_coeffs_lowbd() are pre-halved — TODO confirm.
  const __m256i round_const =
      _mm256_set1_epi16((1 << (conv_params->round_0 - 1)) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  (void)filter_params_y;
  (void)subpel_y_q4;

  __m256i filt[4], coeffs[4];

  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_q4, coeffs);

  // Condition for checking valid horz_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_horiz_4tap = 1;

  // horz_filt as 4 tap
  if (is_horiz_4tap) {
    const int fo_horiz = 1;
    const uint8_t *const src_ptr = src - fo_horiz;
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        // Row i in the low lane, row i + 1 in the high lane.
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x_4tap(data, coeffs + 1, filt);
        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);
        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            // Store 4 pixels per row. memcpy is used instead of a uint32_t*
            // store: dst0 is uint8_t storage with no known 4-byte alignment,
            // so a type-punned store would be undefined behaviour.
            const int32_t row0 = _mm_cvtsi128_si32(res_0);
            const int32_t row1 = _mm_cvtsi128_si32(res_1);
            memcpy(&dst0[i * dst_stride0 + j], &row0, sizeof(row0));
            memcpy(&dst0[i * dst_stride0 + j + dst_stride0], &row1,
                   sizeof(row1));
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  } else {
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        // Row i in the low lane, row i + 1 in the high lane.
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x(data, coeffs, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);

        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            // Store 4 pixels per row; memcpy avoids a misaligned,
            // type-punned uint32_t* store into uint8_t storage.
            const int32_t row0 = _mm_cvtsi128_si32(res_0);
            const int32_t row1 = _mm_cvtsi128_si32(res_1);
            memcpy(&dst0[i * dst_stride0 + j], &row0, sizeof(row0));
            memcpy(&dst0[i * dst_stride0 + j + dst_stride0], &row1,
                   sizeof(row1));
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}
189 
// Vertical-only compound convolution for 8-bit input (AVX2).
//
// Filters `src` with filter_params_y at sub-pixel phase subpel_y_q4. Columns
// are processed in strips of 16. A strip narrower than 16 (w - j < 16) only
// computes the low eight columns and stores 8 or 4 pixels per row. Within a
// strip, two output rows are produced per iteration. The byte-interleaved
// source rows are kept in a sliding window s[], shifted at the end of each
// iteration, so every source row is loaded once per strip. The result is then
// either
//  - stored as an offset 16-bit intermediate to conv_params->dst with stride
//    conv_params->dst_stride (do_average == 0), or
//  - combined with the intermediate already in conv_params->dst using
//    comp_avg(), rounded with convolve_rounding(), packed to 8 bits and
//    written to dst0 with stride dst_stride0 (do_average != 0).
//
// A 4-tap path is used when the low 32 bits of coeffs[0] | coeffs[3] are zero
// (presumably: the outermost filter taps are zero).
// filter_params_x and subpel_x_q4 are unused.
//
// NOTE(review): assumes h is even and w is 4, 8 or a multiple of 16 — confirm
// against callers.
void av1_dist_wtd_convolve_y_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_x_q4, const int subpel_y_q4,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_vert_4tap = 0;
  // +1 to compensate for dividing the filter coeffs by 2
  const int left_shift = FILTER_BITS - conv_params->round_0 + 1;
  const __m256i round_const =
      _mm256_set1_epi32((1 << conv_params->round_1) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  // offset_1 is added to the 16-bit filter output before it is zero-extended
  // to 32 bits (presumably to make it non-negative so the zero-extension is
  // exact). After the shift left by left_shift and right by round_1 it
  // contributes (1 << (offset_0 - 1)). Adding offset_const_2 (1 << offset_0)
  // afterwards makes the total bias equal to `offset`.
  const int offset_1 = (1 << (bd + FILTER_BITS - 2));
  const __m256i offset_const_1 = _mm256_set1_epi16(offset_1);
  const __m256i offset_const_2 = _mm256_set1_epi16((1 << offset_0));
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i coeffs[4], s[8];

  assert((FILTER_BITS - conv_params->round_0) >= 0);

  prepare_coeffs_lowbd(filter_params_y, subpel_y_q4, coeffs);

  (void)filter_params_x;
  (void)subpel_x_q4;

  // Condition for checking valid vert_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_vert_4tap = 1;

  if (is_vert_4tap) {
    const int fo_vert = 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src4;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      // (src_ab[k] holds rows k and k + 1). After the byte interleave:
      //   s[0] / s[3]: rows (0,1) low lane, (1,2) high lane, columns 0-7 / 8-15
      //   s[1] / s[4]: rows (2,3) low lane, (3,4) high lane, columns 0-7 / 8-15
      {
        __m256i src_ab[4];
        __m256i src_a[5];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 4; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src4 = src_a[4];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);

        s[3] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[4] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
      }

      for (i = 0; i < h; i += 2) {
        // Extend the window by two rows: s[2] / s[5] get the newest pairs.
        data = &src_ptr[(i + 5) * src_stride + j];
        const __m256i src5 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_45a = _mm256_permute2x128_si256(src4, src5, 0x20);

        src4 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_56a = _mm256_permute2x128_si256(src5, src4, 0x20);

        s[2] = _mm256_unpacklo_epi8(src_45a, src_56a);
        s[5] = _mm256_unpackhi_epi8(src_45a, src_56a);

        __m256i res_lo = convolve_lowbd_4tap(s, coeffs + 1);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              // Store 4 pixels per row; memcpy avoids a misaligned,
              // type-punned uint32_t* store into uint8_t storage.
              const int32_t row0 = _mm_cvtsi128_si32(res_0);
              const int32_t row1 = _mm_cvtsi128_si32(res_1);
              memcpy(&dst0[i * dst_stride0 + j], &row0, sizeof(row0));
              memcpy(&dst0[i * dst_stride0 + j + dst_stride0], &row1,
                     sizeof(row1));
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd_4tap(s + 3, coeffs + 1);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        // Slide the window down by two rows.
        s[0] = s[1];
        s[1] = s[2];

        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src6;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      // (src_ab[k] holds rows k and k + 1). s[0..2] cover columns 0-7 and
      // s[4..6] columns 8-15 for row pairs (0,1)/(1,2), (2,3)/(3,4),
      // (4,5)/(5,6). s[3] / s[7] are filled inside the loop.
      {
        __m256i src_ab[7];
        __m256i src_a[7];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 6; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src6 = src_a[6];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);
        s[2] = _mm256_unpacklo_epi8(src_ab[4], src_ab[5]);
        s[4] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[5] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
        s[6] = _mm256_unpackhi_epi8(src_ab[4], src_ab[5]);
      }

      for (i = 0; i < h; i += 2) {
        // Extend the window by two rows: s[3] / s[7] get the newest pairs.
        data = &src_ptr[(i + 7) * src_stride + j];
        const __m256i src7 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_67a = _mm256_permute2x128_si256(src6, src7, 0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_78a = _mm256_permute2x128_si256(src7, src6, 0x20);

        s[3] = _mm256_unpacklo_epi8(src_67a, src_78a);
        s[7] = _mm256_unpackhi_epi8(src_67a, src_78a);

        __m256i res_lo = convolve_lowbd(s, coeffs);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              // Store 4 pixels per row; memcpy avoids a misaligned,
              // type-punned uint32_t* store into uint8_t storage.
              const int32_t row0 = _mm_cvtsi128_si32(res_0);
              const int32_t row1 = _mm_cvtsi128_si32(res_1);
              memcpy(&dst0[i * dst_stride0 + j], &row0, sizeof(row0));
              memcpy(&dst0[i * dst_stride0 + j + dst_stride0], &row1,
                     sizeof(row1));
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd(s + 4, coeffs);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        // Slide the window down by two rows.
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}
594 
av1_dist_wtd_convolve_2d_avx2(const uint8_t * src,int src_stride,uint8_t * dst0,int dst_stride0,int w,int h,const InterpFilterParams * filter_params_x,const InterpFilterParams * filter_params_y,const int subpel_x_q4,const int subpel_y_q4,ConvolveParams * conv_params)595 void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
596                                    uint8_t *dst0, int dst_stride0, int w, int h,
597                                    const InterpFilterParams *filter_params_x,
598                                    const InterpFilterParams *filter_params_y,
599                                    const int subpel_x_q4, const int subpel_y_q4,
600                                    ConvolveParams *conv_params) {
601   CONV_BUF_TYPE *dst = conv_params->dst;
602   int dst_stride = conv_params->dst_stride;
603   const int bd = 8;
604 
605   DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);
606 
607   int im_stride = 8;
608   int i, is_horiz_4tap = 0, is_vert_4tap = 0;
609   const __m256i wt = unpack_weights_avx2(conv_params);
610   const int do_average = conv_params->do_average;
611   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
612   const int offset_0 =
613       bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
614   const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
615   const __m256i offset_const = _mm256_set1_epi16(offset);
616   const int rounding_shift =
617       2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
618   const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
619 
620   assert(conv_params->round_0 > 0);
621 
622   const __m256i round_const_h = _mm256_set1_epi16(
623       ((1 << (conv_params->round_0 - 1)) >> 1) + (1 << (bd + FILTER_BITS - 2)));
624   const __m128i round_shift_h = _mm_cvtsi32_si128(conv_params->round_0 - 1);
625 
626   const __m256i round_const_v = _mm256_set1_epi32(
627       ((1 << conv_params->round_1) >> 1) -
628       (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
629   const __m128i round_shift_v = _mm_cvtsi32_si128(conv_params->round_1);
630 
631   __m256i filt[4], coeffs_x[4], coeffs_y[4];
632 
633   filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
634   filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));
635 
636   prepare_coeffs_lowbd(filter_params_x, subpel_x_q4, coeffs_x);
637   prepare_coeffs(filter_params_y, subpel_y_q4, coeffs_y);
638 
639   // Condition for checking valid horz_filt taps
640   if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_x[0], coeffs_x[3]), 0)))
641     is_horiz_4tap = 1;
642 
643   // Condition for checking valid vert_filt taps
644   if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_y[0], coeffs_y[3]), 0)))
645     is_vert_4tap = 1;
646 
647   if (is_horiz_4tap) {
648     int im_h = h + filter_params_y->taps - 1;
649     const int fo_vert = filter_params_y->taps / 2 - 1;
650     const int fo_horiz = 1;
651     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
652     for (int j = 0; j < w; j += 8) {
653       /* Horizontal filter */
654       const uint8_t *src_h = src_ptr + j;
655       for (i = 0; i < im_h; i += 2) {
656         __m256i data =
657             _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h));
658         if (i + 1 < im_h)
659           data = _mm256_inserti128_si256(
660               data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1);
661         src_h += (src_stride << 1);
662         __m256i res = convolve_lowbd_x_4tap(data, coeffs_x + 1, filt);
663 
664         res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h),
665                                round_shift_h);
666 
667         _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
668       }
669       DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
670     }
671   } else if (is_vert_4tap) {
672     int im_h = h + 3;
673     const int fo_vert = 1;
674     const int fo_horiz = filter_params_x->taps / 2 - 1;
675     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
676 
677     filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
678     filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
679 
680     for (int j = 0; j < w; j += 8) {
681       /* Horizontal filter */
682       const uint8_t *src_h = src_ptr + j;
683       DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;
684 
685       /* Vertical filter */
686       __m256i s[6];
687       __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
688       __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
689       __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
690       __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
691 
692       s[0] = _mm256_unpacklo_epi16(s0, s1);
693       s[1] = _mm256_unpacklo_epi16(s2, s3);
694 
695       s[3] = _mm256_unpackhi_epi16(s0, s1);
696       s[4] = _mm256_unpackhi_epi16(s2, s3);
697 
698       for (i = 0; i < h; i += 2) {
699         const int16_t *data = &im_block[i * im_stride];
700 
701         const __m256i s4 =
702             _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
703         const __m256i s5 =
704             _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
705 
706         s[2] = _mm256_unpacklo_epi16(s4, s5);
707         s[5] = _mm256_unpackhi_epi16(s4, s5);
708 
709         const __m256i res_a = convolve_4tap(s, coeffs_y + 1);
710         const __m256i res_a_round = _mm256_sra_epi32(
711             _mm256_add_epi32(res_a, round_const_v), round_shift_v);
712 
713         if (w - j > 4) {
714           const __m256i res_b = convolve_4tap(s + 3, coeffs_y + 1);
715           const __m256i res_b_round = _mm256_sra_epi32(
716               _mm256_add_epi32(res_b, round_const_v), round_shift_v);
717           const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);
718           const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
719 
720           if (do_average) {
721             const __m256i data_ref_0 =
722                 load_line2_avx2(&dst[i * dst_stride + j],
723                                 &dst[i * dst_stride + j + dst_stride]);
724             const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
725                                                   &wt, use_dist_wtd_comp_avg);
726 
727             const __m256i round_result = convolve_rounding(
728                 &comp_avg_res, &offset_const, &rounding_const, rounding_shift);
729 
730             const __m256i res_8 =
731                 _mm256_packus_epi16(round_result, round_result);
732             const __m128i res_0 = _mm256_castsi256_si128(res_8);
733             const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
734 
735             _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
736             _mm_storel_epi64(
737                 (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
738           } else {
739             const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
740             _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
741 
742             const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
743             _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
744                             res_1);
745           }
746         } else {
747           const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
748           const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
749 
750           if (do_average) {
751             const __m256i data_ref_0 =
752                 load_line2_avx2(&dst[i * dst_stride + j],
753                                 &dst[i * dst_stride + j + dst_stride]);
754 
755             const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
756                                                   &wt, use_dist_wtd_comp_avg);
757 
758             const __m256i round_result = convolve_rounding(
759                 &comp_avg_res, &offset_const, &rounding_const, rounding_shift);
760 
761             const __m256i res_8 =
762                 _mm256_packus_epi16(round_result, round_result);
763             const __m128i res_0 = _mm256_castsi256_si128(res_8);
764             const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
765 
766             *(uint32_t *)(&dst0[i * dst_stride0 + j]) =
767                 _mm_cvtsi128_si32(res_0);
768             *(uint32_t *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
769                 _mm_cvtsi128_si32(res_1);
770 
771           } else {
772             const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
773             _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
774 
775             const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
776             _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
777                             res_1);
778           }
779         }
780         s[0] = s[1];
781         s[1] = s[2];
782         s[3] = s[4];
783         s[4] = s[5];
784       }
785     }
786   } else {
787     int im_h = h + filter_params_y->taps - 1;
788     const int fo_vert = filter_params_y->taps / 2 - 1;
789     const int fo_horiz = filter_params_x->taps / 2 - 1;
790     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
791 
792     filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
793     filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
794 
795     for (int j = 0; j < w; j += 8) {
796       /* Horizontal filter */
797       const uint8_t *src_h = src_ptr + j;
798       DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;
799 
800       DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
801     }
802   }
803 }
804 
/*
 * Distance-weighted compound "copy" prediction (no sub-pixel filtering) for
 * 8-bit input.
 *
 * Each source pixel is left-shifted by 2 * FILTER_BITS - round_0 - round_1
 * and biased by `offset_const`, putting it in the same intermediate
 * (CONV_BUF_TYPE) domain the filtering kernels produce.
 *
 * - do_average == 0: the biased value is stored to conv_params->dst.
 * - do_average != 0: the biased value is combined with the value already in
 *   conv_params->dst via comp_avg() (weighted by fwd_offset/bck_offset when
 *   use_dist_wtd_comp_avg is set), brought back to pixel precision with
 *   convolve_rounding(), saturated to 8 bits and stored to dst0.
 *
 * filter_params_x/y and subpel_x_q4/y_q4 are unused; they are kept so the
 * signature matches the other convolve kernels.
 *
 * Preconditions implied by the code below (NOTE(review): confirm callers):
 * - Only widths that are a multiple of 4 are handled; any other width
 *   returns without writing anything.
 * - When w is not a multiple of 16, rows are processed in pairs, so h is
 *   expected to be even.
 * - The w % 16 == 0 path uses aligned stores: every 16-pixel column start
 *   must be 16-byte aligned in dst0 and 32-byte aligned in
 *   conv_params->dst.
 * - The narrow path always reads and writes 8 CONV_BUF_TYPE entries of
 *   conv_params->dst per 8-column step, so that buffer needs slack up to
 *   the next multiple of 8 columns.
 */
void av1_dist_wtd_convolve_2d_copy_avx2(
    const uint8_t *src, int src_stride, uint8_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_q4,
    const int subpel_y_q4, ConvolveParams *conv_params) {
  const int bd = 8;
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  (void)filter_params_x;
  (void)filter_params_y;
  (void)subpel_x_q4;
  (void)subpel_y_q4;

  const int bits =
      FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
  const __m128i left_shift = _mm_cvtsi32_si128(bits);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const __m256i zero = _mm256_setzero_si256();

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  if (!(w % 16)) {
    // Wide path: 16 pixels of a single row per iteration.
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; j += 16) {
        const __m256i src_16bit = _mm256_cvtepu8_epi16(
            _mm_loadu_si128((__m128i *)(&src[i * src_stride + j])));

        const __m256i res = _mm256_sll_epi16(src_16bit, left_shift);
        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        if (do_average) {
          const __m256i data_ref_0 =
              _mm256_loadu_si256((__m256i *)(&dst[i * dst_stride + j]));

          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          // packus works per 128-bit lane, leaving pixels 0-7 in qword 0 and
          // pixels 8-15 in qword 2; 0xD8 moves them into the low 128 bits.
          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m256i res_0 = _mm256_permute4x64_epi64(res_8, 0xD8);

          _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]),
                          _mm256_castsi256_si128(res_0));
        } else {
          _mm256_store_si256((__m256i *)(&dst[i * dst_stride + j]),
                             res_unsigned);
        }
      }
    }
  } else if (!(w % 4)) {
    // Narrow path: 8 columns of two rows per iteration; row i occupies the
    // low 128-bit lane and row i + 1 the high lane.
    for (int i = 0; i < h; i += 2) {
      for (int j = 0; j < w; j += 8) {
        const __m128i src_row_0 =
            _mm_loadl_epi64((__m128i *)(&src[i * src_stride + j]));
        const __m128i src_row_1 =
            _mm_loadl_epi64((__m128i *)(&src[i * src_stride + j + src_stride]));
        // _mm256_set_m128i() is not available on all compilers, so build the
        // two-row vector with an integer-domain lane insert.
        const __m256i src_10 = _mm256_inserti128_si256(
            _mm256_castsi128_si256(src_row_0), src_row_1, 1);

        // unpacklo is per-lane: widens the 8 bytes of each row to 16 bits.
        const __m256i src_16bit = _mm256_unpacklo_epi8(src_10, zero);

        const __m256i res = _mm256_sll_epi16(src_16bit, left_shift);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 = load_line2_avx2(
              &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          // Decide on the columns remaining in THIS step, not on the total
          // width: when only 4 remain (w == 4, or the last step of w == 12,
          // 20, ...), an 8-byte store would overwrite 4 pixels of the
          // neighbouring block in dst0.
          if (w - j > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(uint32_t *)(&dst0[i * dst_stride0 + j]) =
                _mm_cvtsi128_si32(res_0);
            *(uint32_t *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}
918