1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #ifndef AOM_AOM_DSP_X86_CONVOLVE_AVX2_H_
13 #define AOM_AOM_DSP_X86_CONVOLVE_AVX2_H_
14 
15 // filters for 16
16 DECLARE_ALIGNED(32, static const uint8_t, filt_global_avx2[]) = {
17   0,  1,  1,  2,  2, 3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  0,  1,  1,
18   2,  2,  3,  3,  4, 4,  5,  5,  6,  6,  7,  7,  8,  2,  3,  3,  4,  4,  5,
19   5,  6,  6,  7,  7, 8,  8,  9,  9,  10, 2,  3,  3,  4,  4,  5,  5,  6,  6,
20   7,  7,  8,  8,  9, 9,  10, 4,  5,  5,  6,  6,  7,  7,  8,  8,  9,  9,  10,
21   10, 11, 11, 12, 4, 5,  5,  6,  6,  7,  7,  8,  8,  9,  9,  10, 10, 11, 11,
22   12, 6,  7,  7,  8, 8,  9,  9,  10, 10, 11, 11, 12, 12, 13, 13, 14, 6,  7,
23   7,  8,  8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14
24 };
25 
26 DECLARE_ALIGNED(32, static const uint8_t, filt_d4_global_avx2[]) = {
27   0, 1, 2, 3,  1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, 0, 1, 2, 3,  1, 2,
28   3, 4, 2, 3,  4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8, 6, 7,  8, 9,
29   7, 8, 9, 10, 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
30 };
31 
32 DECLARE_ALIGNED(32, static const uint8_t, filt4_d4_global_avx2[]) = {
33   2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8,
34   2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8,
35 };
36 
// Horizontal 8-tap pass that fills the intermediate block `im_block`.
//
// This is a statement macro, not an expression. It expands in the caller's
// scope and uses these caller-provided names: i, j, im_h, src_ptr,
// src_stride, coeffs_h, filt, round_const_h, round_shift_h, im_block and
// im_stride.
//
// Main loop (i = 0, 2, 4, ... while i < im_h - 2): loads 16 source bytes from
// row i into the low 128-bit lane and 16 bytes from row i + 1 into the high
// lane, filters both with convolve_lowbd_x(), adds round_const_h, shifts
// right arithmetically by round_shift_h, and stores the full 256-bit result
// at &im_block[i * im_stride] with an aligned store (that address must
// therefore be 32-byte aligned).
//
// Tail: after the loop one more row (row i) is filtered from a register whose
// upper lane is left undefined by the 128->256 cast, and a full 256 bits are
// stored. Only the low half of that last store carries meaningful data.
// When im_h is odd the loop plus tail cover every row; with an even im_h the
// last row would not be produced -- presumably callers always pass an odd
// im_h (TODO confirm).
//
// The tail declares `data_1` and `res` directly in the caller's scope, so the
// macro can be expanded at most once per scope.
#define CONVOLVE_SR_HORIZONTAL_FILTER_8TAP                                     \
  for (i = 0; i < (im_h - 2); i += 2) {                                        \
    __m256i data = _mm256_castsi128_si256(                                     \
        _mm_loadu_si128((__m128i *)&src_ptr[(i * src_stride) + j]));           \
    data = _mm256_inserti128_si256(                                            \
        data,                                                                  \
        _mm_loadu_si128(                                                       \
            (__m128i *)&src_ptr[(i * src_stride) + j + src_stride]),           \
        1);                                                                    \
                                                                               \
    __m256i res = convolve_lowbd_x(data, coeffs_h, filt);                      \
    res =                                                                      \
        _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h), round_shift_h); \
    _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);              \
  }                                                                            \
                                                                               \
  __m256i data_1 = _mm256_castsi128_si256(                                     \
      _mm_loadu_si128((__m128i *)&src_ptr[(i * src_stride) + j]));             \
                                                                               \
  __m256i res = convolve_lowbd_x(data_1, coeffs_h, filt);                      \
                                                                               \
  res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h), round_shift_h); \
                                                                               \
  _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
61 
// Vertical 8-tap pass over `im_block` that writes 8-bit pixels to `dst`.
//
// This is a statement macro. It expands in the caller's scope and uses these
// caller-provided names: i, j, h, w, im_block, im_stride, coeffs_v,
// sum_round_v, sum_shift_v, round_const_v, round_shift_v, dst, dst_stride,
// plus the helpers convolve() and xx_storel_32(). It declares src_0..src_5
// and s[8] in the caller's scope, so it can be expanded once per scope.
//
// Set-up: six 256-bit loads at im_block + k * im_stride (k = 0..5) are
// interleaved pairwise as 16-bit elements. s[0..2] hold the unpacklo halves
// and s[4..6] the unpackhi halves of rows (0,1), (2,3), (4,5).
//
// Each iteration (i advances by 2):
//   * loads the rows at offsets 6 and 7 from &im_block[i * im_stride] and
//     interleaves them into s[3] (low halves) and s[7] (high halves);
//   * runs convolve() on s[0..3] and on s[4..7] to get 32-bit sums;
//   * applies two rounding stages: (x + sum_round_v) >> sum_shift_v, then
//     (x + round_const_v) >> round_shift_v, both arithmetic shifts;
//   * packs to 16 bits with signed saturation, then to 8 bits with unsigned
//     saturation;
//   * writes the low 128-bit lane to dst row i and the high lane to row
//     i + 1: 8 bytes per row when w - j > 4, 4 bytes when w == 4, otherwise
//     2 bytes through a uint16_t store (the int from _mm_cvtsi128_si32 is
//     implicitly truncated to 16 bits);
//   * slides the window: s[0..2] <- s[1..3] and s[4..6] <- s[5..7].
//
// NOTE(review): the overlapping 256-bit loads one im_stride apart suggest
// im_stride is 8 int16 values, so each load spans two consecutive rows --
// confirm against the caller.
#define CONVOLVE_SR_VERTICAL_FILTER_8TAP                                      \
  __m256i src_0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));  \
  __m256i src_1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));  \
  __m256i src_2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));  \
  __m256i src_3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));  \
  __m256i src_4 = _mm256_loadu_si256((__m256i *)(im_block + 4 * im_stride));  \
  __m256i src_5 = _mm256_loadu_si256((__m256i *)(im_block + 5 * im_stride));  \
                                                                              \
  __m256i s[8];                                                               \
  s[0] = _mm256_unpacklo_epi16(src_0, src_1);                                 \
  s[1] = _mm256_unpacklo_epi16(src_2, src_3);                                 \
  s[2] = _mm256_unpacklo_epi16(src_4, src_5);                                 \
                                                                              \
  s[4] = _mm256_unpackhi_epi16(src_0, src_1);                                 \
  s[5] = _mm256_unpackhi_epi16(src_2, src_3);                                 \
  s[6] = _mm256_unpackhi_epi16(src_4, src_5);                                 \
                                                                              \
  for (i = 0; i < h; i += 2) {                                                \
    const int16_t *data = &im_block[i * im_stride];                           \
                                                                              \
    const __m256i s6 = _mm256_loadu_si256((__m256i *)(data + 6 * im_stride)); \
    const __m256i s7 = _mm256_loadu_si256((__m256i *)(data + 7 * im_stride)); \
                                                                              \
    s[3] = _mm256_unpacklo_epi16(s6, s7);                                     \
    s[7] = _mm256_unpackhi_epi16(s6, s7);                                     \
                                                                              \
    __m256i res_a = convolve(s, coeffs_v);                                    \
    __m256i res_b = convolve(s + 4, coeffs_v);                                \
                                                                              \
    res_a =                                                                   \
        _mm256_sra_epi32(_mm256_add_epi32(res_a, sum_round_v), sum_shift_v);  \
    res_b =                                                                   \
        _mm256_sra_epi32(_mm256_add_epi32(res_b, sum_round_v), sum_shift_v);  \
                                                                              \
    const __m256i res_a_round = _mm256_sra_epi32(                             \
        _mm256_add_epi32(res_a, round_const_v), round_shift_v);               \
    const __m256i res_b_round = _mm256_sra_epi32(                             \
        _mm256_add_epi32(res_b, round_const_v), round_shift_v);               \
                                                                              \
    const __m256i res_16bit = _mm256_packs_epi32(res_a_round, res_b_round);   \
    const __m256i res_8b = _mm256_packus_epi16(res_16bit, res_16bit);         \
                                                                              \
    const __m128i res_0 = _mm256_castsi256_si128(res_8b);                     \
    const __m128i res_1 = _mm256_extracti128_si256(res_8b, 1);                \
                                                                              \
    __m128i *const p_0 = (__m128i *)&dst[i * dst_stride + j];                 \
    __m128i *const p_1 = (__m128i *)&dst[i * dst_stride + j + dst_stride];    \
    if (w - j > 4) {                                                          \
      _mm_storel_epi64(p_0, res_0);                                           \
      _mm_storel_epi64(p_1, res_1);                                           \
    } else if (w == 4) {                                                      \
      xx_storel_32(p_0, res_0);                                               \
      xx_storel_32(p_1, res_1);                                               \
    } else {                                                                  \
      *(uint16_t *)p_0 = _mm_cvtsi128_si32(res_0);                            \
      *(uint16_t *)p_1 = _mm_cvtsi128_si32(res_1);                            \
    }                                                                         \
                                                                              \
    s[0] = s[1];                                                              \
    s[1] = s[2];                                                              \
    s[2] = s[3];                                                              \
                                                                              \
    s[4] = s[5];                                                              \
    s[5] = s[6];                                                              \
    s[6] = s[7];                                                              \
  }
128 
// Horizontal 8-tap pass for the compound (dist-wtd) path; fills `im_block`.
//
// This is a statement macro. It expands in the caller's scope and uses these
// caller-provided names: i, im_h, src_h, src_stride, coeffs_x, filt,
// round_const_h, round_shift_h, im_block and im_stride.
//
// Each iteration (i advances by 2) loads 16 bytes at src_h into the low
// 128-bit lane. The row at src_h + src_stride goes into the high lane only
// when i + 1 < im_h, so nothing is read past the last row; on that final odd
// row the high lane is left undefined. Both lanes are filtered with
// convolve_lowbd_x(), rounded as (x + round_const_h) >> round_shift_h
// (arithmetic shift), and written with an aligned 256-bit store to
// &im_block[i * im_stride], which must therefore be 32-byte aligned.
//
// Side effect: the caller's `src_h` pointer is advanced by 2 * src_stride on
// every iteration and is left pointing past the processed rows.
#define DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP                               \
  for (i = 0; i < im_h; i += 2) {                                              \
    __m256i data = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h));  \
    if (i + 1 < im_h)                                                          \
      data = _mm256_inserti128_si256(                                          \
          data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1);          \
    src_h += (src_stride << 1);                                                \
    __m256i res = convolve_lowbd_x(data, coeffs_x, filt);                      \
                                                                               \
    res =                                                                      \
        _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h), round_shift_h); \
                                                                               \
    _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);              \
  }
143 
// Vertical 8-tap pass for the compound (dist-wtd) path.
//
// This is a statement macro. It expands in the caller's scope and uses these
// caller-provided names: i, j, h, w, im_block, im_stride, coeffs_y,
// round_const_v, round_shift_v, offset_const, do_average, dst, dst_stride,
// dst0, dst_stride0, wt, use_dist_wtd_comp_avg, rounding_const and
// rounding_shift, plus the helpers convolve(), comp_avg(),
// convolve_rounding() and load_line2_avx2(). It declares s[8] and s0..s5 in
// the caller's scope, so it can be expanded once per scope.
//
// Set-up: six 256-bit loads at im_block + k * im_stride (k = 0..5) are
// interleaved pairwise as 16-bit elements into s[0..2] (unpacklo halves) and
// s[4..6] (unpackhi halves).
//
// Each iteration (i advances by 2):
//   * loads the rows at offsets 6 and 7 from &im_block[i * im_stride] into
//     s[3] / s[7];
//   * res_a = convolve(s, coeffs_y), rounded as
//     (x + round_const_v) >> round_shift_v (arithmetic shift);
//   * when w - j > 4 the high halves are filtered too (res_b) and packed
//     together with res_a; otherwise res_a is packed with itself;
//   * offset_const is added to the packed 16-bit result (res_unsigned);
//   * if do_average: two rows are read from `dst` via load_line2_avx2(),
//     combined with res_unsigned by comp_avg(), passed through
//     convolve_rounding(), packed to 8 bits with unsigned saturation, and
//     written to `dst0` -- 8 bytes per row in the wide branch, 4 bytes per
//     row (through a uint32_t store) in the narrow branch;
//   * otherwise each 128-bit lane of res_unsigned is written to `dst` rows
//     i and i + 1 with an aligned 128-bit store, in both branches, so those
//     addresses must be 16-byte aligned and 16 bytes are written even in
//     the narrow branch;
//   * slides the window: s[0..2] <- s[1..3] and s[4..6] <- s[5..7].
#define DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP                                 \
  __m256i s[8];                                                                \
  __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));      \
  __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));      \
  __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));      \
  __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));      \
  __m256i s4 = _mm256_loadu_si256((__m256i *)(im_block + 4 * im_stride));      \
  __m256i s5 = _mm256_loadu_si256((__m256i *)(im_block + 5 * im_stride));      \
                                                                               \
  s[0] = _mm256_unpacklo_epi16(s0, s1);                                        \
  s[1] = _mm256_unpacklo_epi16(s2, s3);                                        \
  s[2] = _mm256_unpacklo_epi16(s4, s5);                                        \
                                                                               \
  s[4] = _mm256_unpackhi_epi16(s0, s1);                                        \
  s[5] = _mm256_unpackhi_epi16(s2, s3);                                        \
  s[6] = _mm256_unpackhi_epi16(s4, s5);                                        \
                                                                               \
  for (i = 0; i < h; i += 2) {                                                 \
    const int16_t *data = &im_block[i * im_stride];                            \
                                                                               \
    const __m256i s6 = _mm256_loadu_si256((__m256i *)(data + 6 * im_stride));  \
    const __m256i s7 = _mm256_loadu_si256((__m256i *)(data + 7 * im_stride));  \
                                                                               \
    s[3] = _mm256_unpacklo_epi16(s6, s7);                                      \
    s[7] = _mm256_unpackhi_epi16(s6, s7);                                      \
                                                                               \
    const __m256i res_a = convolve(s, coeffs_y);                               \
    const __m256i res_a_round = _mm256_sra_epi32(                              \
        _mm256_add_epi32(res_a, round_const_v), round_shift_v);                \
                                                                               \
    if (w - j > 4) {                                                           \
      const __m256i res_b = convolve(s + 4, coeffs_y);                         \
      const __m256i res_b_round = _mm256_sra_epi32(                            \
          _mm256_add_epi32(res_b, round_const_v), round_shift_v);              \
      const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);    \
      const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);    \
                                                                               \
      if (do_average) {                                                        \
        const __m256i data_ref_0 = load_line2_avx2(                            \
            &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]);  \
        const __m256i comp_avg_res =                                           \
            comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);  \
                                                                               \
        const __m256i round_result = convolve_rounding(                        \
            &comp_avg_res, &offset_const, &rounding_const, rounding_shift);    \
                                                                               \
        const __m256i res_8 = _mm256_packus_epi16(round_result, round_result); \
        const __m128i res_0 = _mm256_castsi256_si128(res_8);                   \
        const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);              \
                                                                               \
        _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);      \
        _mm_storel_epi64(                                                      \
            (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);   \
      } else {                                                                 \
        const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);            \
        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);         \
                                                                               \
        const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);       \
        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),    \
                        res_1);                                                \
      }                                                                        \
    } else {                                                                   \
      const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);    \
      const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);    \
                                                                               \
      if (do_average) {                                                        \
        const __m256i data_ref_0 = load_line2_avx2(                            \
            &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]);  \
                                                                               \
        const __m256i comp_avg_res =                                           \
            comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);  \
                                                                               \
        const __m256i round_result = convolve_rounding(                        \
            &comp_avg_res, &offset_const, &rounding_const, rounding_shift);    \
                                                                               \
        const __m256i res_8 = _mm256_packus_epi16(round_result, round_result); \
        const __m128i res_0 = _mm256_castsi256_si128(res_8);                   \
        const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);              \
                                                                               \
        *(uint32_t *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);  \
        *(uint32_t *)(&dst0[i * dst_stride0 + j + dst_stride0]) =              \
            _mm_cvtsi128_si32(res_1);                                          \
                                                                               \
      } else {                                                                 \
        const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);            \
        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);         \
                                                                               \
        const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);       \
        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),    \
                        res_1);                                                \
      }                                                                        \
    }                                                                          \
                                                                               \
    s[0] = s[1];                                                               \
    s[1] = s[2];                                                               \
    s[2] = s[3];                                                               \
                                                                               \
    s[4] = s[5];                                                               \
    s[5] = s[6];                                                               \
    s[6] = s[7];                                                               \
  }
prepare_coeffs_lowbd(const InterpFilterParams * const filter_params,const int subpel_q4,__m256i * const coeffs)245 static INLINE void prepare_coeffs_lowbd(
246     const InterpFilterParams *const filter_params, const int subpel_q4,
247     __m256i *const coeffs /* [4] */) {
248   const int16_t *const filter = av1_get_interp_filter_subpel_kernel(
249       filter_params, subpel_q4 & SUBPEL_MASK);
250   const __m128i coeffs_8 = _mm_loadu_si128((__m128i *)filter);
251   const __m256i filter_coeffs = _mm256_broadcastsi128_si256(coeffs_8);
252 
253   // right shift all filter co-efficients by 1 to reduce the bits required.
254   // This extra right shift will be taken care of at the end while rounding
255   // the result.
256   // Since all filter co-efficients are even, this change will not affect the
257   // end result
258   assert(_mm_test_all_zeros(_mm_and_si128(coeffs_8, _mm_set1_epi16(1)),
259                             _mm_set1_epi16(0xffff)));
260 
261   const __m256i coeffs_1 = _mm256_srai_epi16(filter_coeffs, 1);
262 
263   // coeffs 0 1 0 1 0 1 0 1
264   coeffs[0] = _mm256_shuffle_epi8(coeffs_1, _mm256_set1_epi16(0x0200u));
265   // coeffs 2 3 2 3 2 3 2 3
266   coeffs[1] = _mm256_shuffle_epi8(coeffs_1, _mm256_set1_epi16(0x0604u));
267   // coeffs 4 5 4 5 4 5 4 5
268   coeffs[2] = _mm256_shuffle_epi8(coeffs_1, _mm256_set1_epi16(0x0a08u));
269   // coeffs 6 7 6 7 6 7 6 7
270   coeffs[3] = _mm256_shuffle_epi8(coeffs_1, _mm256_set1_epi16(0x0e0cu));
271 }
272 
prepare_coeffs(const InterpFilterParams * const filter_params,const int subpel_q4,__m256i * const coeffs)273 static INLINE void prepare_coeffs(const InterpFilterParams *const filter_params,
274                                   const int subpel_q4,
275                                   __m256i *const coeffs /* [4] */) {
276   const int16_t *filter = av1_get_interp_filter_subpel_kernel(
277       filter_params, subpel_q4 & SUBPEL_MASK);
278 
279   const __m128i coeff_8 = _mm_loadu_si128((__m128i *)filter);
280   const __m256i coeff = _mm256_broadcastsi128_si256(coeff_8);
281 
282   // coeffs 0 1 0 1 0 1 0 1
283   coeffs[0] = _mm256_shuffle_epi32(coeff, 0x00);
284   // coeffs 2 3 2 3 2 3 2 3
285   coeffs[1] = _mm256_shuffle_epi32(coeff, 0x55);
286   // coeffs 4 5 4 5 4 5 4 5
287   coeffs[2] = _mm256_shuffle_epi32(coeff, 0xaa);
288   // coeffs 6 7 6 7 6 7 6 7
289   coeffs[3] = _mm256_shuffle_epi32(coeff, 0xff);
290 }
291 
convolve_lowbd(const __m256i * const s,const __m256i * const coeffs)292 static INLINE __m256i convolve_lowbd(const __m256i *const s,
293                                      const __m256i *const coeffs) {
294   const __m256i res_01 = _mm256_maddubs_epi16(s[0], coeffs[0]);
295   const __m256i res_23 = _mm256_maddubs_epi16(s[1], coeffs[1]);
296   const __m256i res_45 = _mm256_maddubs_epi16(s[2], coeffs[2]);
297   const __m256i res_67 = _mm256_maddubs_epi16(s[3], coeffs[3]);
298 
299   // order: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
300   const __m256i res = _mm256_add_epi16(_mm256_add_epi16(res_01, res_45),
301                                        _mm256_add_epi16(res_23, res_67));
302 
303   return res;
304 }
305 
convolve_lowbd_4tap(const __m256i * const s,const __m256i * const coeffs)306 static INLINE __m256i convolve_lowbd_4tap(const __m256i *const s,
307                                           const __m256i *const coeffs) {
308   const __m256i res_23 = _mm256_maddubs_epi16(s[0], coeffs[0]);
309   const __m256i res_45 = _mm256_maddubs_epi16(s[1], coeffs[1]);
310 
311   // order: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
312   const __m256i res = _mm256_add_epi16(res_45, res_23);
313 
314   return res;
315 }
316 
convolve(const __m256i * const s,const __m256i * const coeffs)317 static INLINE __m256i convolve(const __m256i *const s,
318                                const __m256i *const coeffs) {
319   const __m256i res_0 = _mm256_madd_epi16(s[0], coeffs[0]);
320   const __m256i res_1 = _mm256_madd_epi16(s[1], coeffs[1]);
321   const __m256i res_2 = _mm256_madd_epi16(s[2], coeffs[2]);
322   const __m256i res_3 = _mm256_madd_epi16(s[3], coeffs[3]);
323 
324   const __m256i res = _mm256_add_epi32(_mm256_add_epi32(res_0, res_1),
325                                        _mm256_add_epi32(res_2, res_3));
326 
327   return res;
328 }
329 
convolve_4tap(const __m256i * const s,const __m256i * const coeffs)330 static INLINE __m256i convolve_4tap(const __m256i *const s,
331                                     const __m256i *const coeffs) {
332   const __m256i res_1 = _mm256_madd_epi16(s[0], coeffs[0]);
333   const __m256i res_2 = _mm256_madd_epi16(s[1], coeffs[1]);
334 
335   const __m256i res = _mm256_add_epi32(res_1, res_2);
336   return res;
337 }
338 
convolve_lowbd_x(const __m256i data,const __m256i * const coeffs,const __m256i * const filt)339 static INLINE __m256i convolve_lowbd_x(const __m256i data,
340                                        const __m256i *const coeffs,
341                                        const __m256i *const filt) {
342   __m256i s[4];
343 
344   s[0] = _mm256_shuffle_epi8(data, filt[0]);
345   s[1] = _mm256_shuffle_epi8(data, filt[1]);
346   s[2] = _mm256_shuffle_epi8(data, filt[2]);
347   s[3] = _mm256_shuffle_epi8(data, filt[3]);
348 
349   return convolve_lowbd(s, coeffs);
350 }
351 
convolve_lowbd_x_4tap(const __m256i data,const __m256i * const coeffs,const __m256i * const filt)352 static INLINE __m256i convolve_lowbd_x_4tap(const __m256i data,
353                                             const __m256i *const coeffs,
354                                             const __m256i *const filt) {
355   __m256i s[2];
356 
357   s[0] = _mm256_shuffle_epi8(data, filt[0]);
358   s[1] = _mm256_shuffle_epi8(data, filt[1]);
359 
360   return convolve_lowbd_4tap(s, coeffs);
361 }
362 
add_store_aligned_256(CONV_BUF_TYPE * const dst,const __m256i * const res,const int do_average)363 static INLINE void add_store_aligned_256(CONV_BUF_TYPE *const dst,
364                                          const __m256i *const res,
365                                          const int do_average) {
366   __m256i d;
367   if (do_average) {
368     d = _mm256_load_si256((__m256i *)dst);
369     d = _mm256_add_epi32(d, *res);
370     d = _mm256_srai_epi32(d, 1);
371   } else {
372     d = *res;
373   }
374   _mm256_store_si256((__m256i *)dst, d);
375 }
376 
comp_avg(const __m256i * const data_ref_0,const __m256i * const res_unsigned,const __m256i * const wt,const int use_dist_wtd_comp_avg)377 static INLINE __m256i comp_avg(const __m256i *const data_ref_0,
378                                const __m256i *const res_unsigned,
379                                const __m256i *const wt,
380                                const int use_dist_wtd_comp_avg) {
381   __m256i res;
382   if (use_dist_wtd_comp_avg) {
383     const __m256i data_lo = _mm256_unpacklo_epi16(*data_ref_0, *res_unsigned);
384     const __m256i data_hi = _mm256_unpackhi_epi16(*data_ref_0, *res_unsigned);
385 
386     const __m256i wt_res_lo = _mm256_madd_epi16(data_lo, *wt);
387     const __m256i wt_res_hi = _mm256_madd_epi16(data_hi, *wt);
388 
389     const __m256i res_lo = _mm256_srai_epi32(wt_res_lo, DIST_PRECISION_BITS);
390     const __m256i res_hi = _mm256_srai_epi32(wt_res_hi, DIST_PRECISION_BITS);
391 
392     res = _mm256_packs_epi32(res_lo, res_hi);
393   } else {
394     const __m256i wt_res = _mm256_add_epi16(*data_ref_0, *res_unsigned);
395     res = _mm256_srai_epi16(wt_res, 1);
396   }
397   return res;
398 }
399 
convolve_rounding(const __m256i * const res_unsigned,const __m256i * const offset_const,const __m256i * const round_const,const int round_shift)400 static INLINE __m256i convolve_rounding(const __m256i *const res_unsigned,
401                                         const __m256i *const offset_const,
402                                         const __m256i *const round_const,
403                                         const int round_shift) {
404   const __m256i res_signed = _mm256_sub_epi16(*res_unsigned, *offset_const);
405   const __m256i res_round = _mm256_srai_epi16(
406       _mm256_add_epi16(res_signed, *round_const), round_shift);
407   return res_round;
408 }
409 
highbd_comp_avg(const __m256i * const data_ref_0,const __m256i * const res_unsigned,const __m256i * const wt0,const __m256i * const wt1,const int use_dist_wtd_comp_avg)410 static INLINE __m256i highbd_comp_avg(const __m256i *const data_ref_0,
411                                       const __m256i *const res_unsigned,
412                                       const __m256i *const wt0,
413                                       const __m256i *const wt1,
414                                       const int use_dist_wtd_comp_avg) {
415   __m256i res;
416   if (use_dist_wtd_comp_avg) {
417     const __m256i wt0_res = _mm256_mullo_epi32(*data_ref_0, *wt0);
418     const __m256i wt1_res = _mm256_mullo_epi32(*res_unsigned, *wt1);
419     const __m256i wt_res = _mm256_add_epi32(wt0_res, wt1_res);
420     res = _mm256_srai_epi32(wt_res, DIST_PRECISION_BITS);
421   } else {
422     const __m256i wt_res = _mm256_add_epi32(*data_ref_0, *res_unsigned);
423     res = _mm256_srai_epi32(wt_res, 1);
424   }
425   return res;
426 }
427 
highbd_convolve_rounding(const __m256i * const res_unsigned,const __m256i * const offset_const,const __m256i * const round_const,const int round_shift)428 static INLINE __m256i highbd_convolve_rounding(
429     const __m256i *const res_unsigned, const __m256i *const offset_const,
430     const __m256i *const round_const, const int round_shift) {
431   const __m256i res_signed = _mm256_sub_epi32(*res_unsigned, *offset_const);
432   const __m256i res_round = _mm256_srai_epi32(
433       _mm256_add_epi32(res_signed, *round_const), round_shift);
434 
435   return res_round;
436 }
437 
438 #endif  // AOM_AOM_DSP_X86_CONVOLVE_AVX2_H_
439