1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 
16 #include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
17 
mm256_add_hi_lo_epi16(const __m256i val)18 static INLINE __m128i mm256_add_hi_lo_epi16(const __m256i val) {
19   return _mm_add_epi16(_mm256_castsi256_si128(val),
20                        _mm256_extractf128_si256(val, 1));
21 }
22 
mm256_add_hi_lo_epi32(const __m256i val)23 static INLINE __m128i mm256_add_hi_lo_epi32(const __m256i val) {
24   return _mm_add_epi32(_mm256_castsi256_si128(val),
25                        _mm256_extractf128_si256(val, 1));
26 }
27 
variance_kernel_avx2(const __m256i src,const __m256i ref,__m256i * const sse,__m256i * const sum)28 static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
29                                         __m256i *const sse,
30                                         __m256i *const sum) {
31   const __m256i adj_sub = _mm256_set1_epi16(0xff01);  // (1,-1)
32 
33   // unpack into pairs of source and reference values
34   const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
35   const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);
36 
37   // subtract adjacent elements using src*1 + ref*-1
38   const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
39   const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
40   const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
41   const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
42 
43   // add to the running totals
44   *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
45   *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
46 }
47 
variance_final_from_32bit_sum_avx2(__m256i vsse,__m128i vsum,unsigned int * const sse)48 static INLINE int variance_final_from_32bit_sum_avx2(__m256i vsse, __m128i vsum,
49                                                      unsigned int *const sse) {
50   // extract the low lane and add it to the high lane
51   const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);
52 
53   // unpack sse and sum registers and add
54   const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
55   const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
56   const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
57 
58   // perform the final summation and extract the results
59   const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
60   *((int *)sse) = _mm_cvtsi128_si32(res);
61   return _mm_extract_epi32(res, 1);
62 }
63 
64 // handle pixels (<= 512)
variance_final_512_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)65 static INLINE int variance_final_512_avx2(__m256i vsse, __m256i vsum,
66                                           unsigned int *const sse) {
67   // extract the low lane and add it to the high lane
68   const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
69   const __m128i vsum_64 = _mm_add_epi16(vsum_128, _mm_srli_si128(vsum_128, 8));
70   const __m128i sum_int32 = _mm_cvtepi16_epi32(vsum_64);
71   return variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse);
72 }
73 
74 // handle 1024 pixels (32x32, 16x64, 64x16)
variance_final_1024_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)75 static INLINE int variance_final_1024_avx2(__m256i vsse, __m256i vsum,
76                                            unsigned int *const sse) {
77   // extract the low lane and add it to the high lane
78   const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
79   const __m128i vsum_64 =
80       _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
81                     _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
82   return variance_final_from_32bit_sum_avx2(vsse, vsum_64, sse);
83 }
84 
sum_to_32bit_avx2(const __m256i sum)85 static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
86   const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
87   const __m256i sum_hi =
88       _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
89   return _mm256_add_epi32(sum_lo, sum_hi);
90 }
91 
92 // handle 2048 pixels (32x64, 64x32)
variance_final_2048_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)93 static INLINE int variance_final_2048_avx2(__m256i vsse, __m256i vsum,
94                                            unsigned int *const sse) {
95   vsum = sum_to_32bit_avx2(vsum);
96   const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);
97   return variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);
98 }
99 
variance16_kernel_avx2(const uint8_t * const src,const int src_stride,const uint8_t * const ref,const int ref_stride,__m256i * const sse,__m256i * const sum)100 static INLINE void variance16_kernel_avx2(
101     const uint8_t *const src, const int src_stride, const uint8_t *const ref,
102     const int ref_stride, __m256i *const sse, __m256i *const sum) {
103   const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
104   const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
105   const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
106   const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
107   const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
108   const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
109   variance_kernel_avx2(s, r, sse, sum);
110 }
111 
variance32_kernel_avx2(const uint8_t * const src,const uint8_t * const ref,__m256i * const sse,__m256i * const sum)112 static INLINE void variance32_kernel_avx2(const uint8_t *const src,
113                                           const uint8_t *const ref,
114                                           __m256i *const sse,
115                                           __m256i *const sum) {
116   const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
117   const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
118   variance_kernel_avx2(s, r, sse, sum);
119 }
120 
variance16_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)121 static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
122                                    const uint8_t *ref, const int ref_stride,
123                                    const int h, __m256i *const vsse,
124                                    __m256i *const vsum) {
125   *vsum = _mm256_setzero_si256();
126 
127   for (int i = 0; i < h; i += 2) {
128     variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
129     src += 2 * src_stride;
130     ref += 2 * ref_stride;
131   }
132 }
133 
variance32_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)134 static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
135                                    const uint8_t *ref, const int ref_stride,
136                                    const int h, __m256i *const vsse,
137                                    __m256i *const vsum) {
138   *vsum = _mm256_setzero_si256();
139 
140   for (int i = 0; i < h; i++) {
141     variance32_kernel_avx2(src, ref, vsse, vsum);
142     src += src_stride;
143     ref += ref_stride;
144   }
145 }
146 
variance64_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)147 static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
148                                    const uint8_t *ref, const int ref_stride,
149                                    const int h, __m256i *const vsse,
150                                    __m256i *const vsum) {
151   *vsum = _mm256_setzero_si256();
152 
153   for (int i = 0; i < h; i++) {
154     variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
155     variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
156     src += src_stride;
157     ref += ref_stride;
158   }
159 }
160 
variance128_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)161 static INLINE void variance128_avx2(const uint8_t *src, const int src_stride,
162                                     const uint8_t *ref, const int ref_stride,
163                                     const int h, __m256i *const vsse,
164                                     __m256i *const vsum) {
165   *vsum = _mm256_setzero_si256();
166 
167   for (int i = 0; i < h; i++) {
168     variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
169     variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
170     variance32_kernel_avx2(src + 64, ref + 64, vsse, vsum);
171     variance32_kernel_avx2(src + 96, ref + 96, vsse, vsum);
172     src += src_stride;
173     ref += ref_stride;
174   }
175 }
176 
// Defines aom_variance<bw>x<bh>_avx2 for blocks of at most 2048 pixels, which
// are accumulated in a single pass of 16-bit sums.
//   bits:      log2(bw * bh); divides sum^2 by the pixel count.
//   max_pixel: selects the matching variance_final_<max_pixel>_avx2 reducer.
#define AOM_VAR_NO_LOOP_AVX2(bw, bh, bits, max_pixel)                         \
  unsigned int aom_variance##bw##x##bh##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    __m256i acc_sse = _mm256_setzero_si256();                                 \
    __m256i acc_sum;                                                          \
    variance##bw##_avx2(src, src_stride, ref, ref_stride, bh, &acc_sse,       \
                        &acc_sum);                                            \
    const int total =                                                         \
        variance_final_##max_pixel##_avx2(acc_sse, acc_sum, sse);             \
    const int64_t total_sq = (int64_t)total * total;                          \
    return *sse - (uint32_t)(total_sq >> bits);                               \
  }

AOM_VAR_NO_LOOP_AVX2(16, 4, 6, 512);
AOM_VAR_NO_LOOP_AVX2(16, 8, 7, 512);
AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512);
AOM_VAR_NO_LOOP_AVX2(16, 32, 9, 512);
AOM_VAR_NO_LOOP_AVX2(16, 64, 10, 1024);

AOM_VAR_NO_LOOP_AVX2(32, 8, 8, 512);
AOM_VAR_NO_LOOP_AVX2(32, 16, 9, 512);
AOM_VAR_NO_LOOP_AVX2(32, 32, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(32, 64, 11, 2048);

AOM_VAR_NO_LOOP_AVX2(64, 16, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(64, 32, 11, 2048);
201 
202 #define AOM_VAR_LOOP_AVX2(bw, bh, bits, uh)                                   \
203   unsigned int aom_variance##bw##x##bh##_avx2(                                \
204       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
205       unsigned int *sse) {                                                    \
206     __m256i vsse = _mm256_setzero_si256();                                    \
207     __m256i vsum = _mm256_setzero_si256();                                    \
208     for (int i = 0; i < (bh / uh); i++) {                                     \
209       __m256i vsum16;                                                         \
210       variance##bw##_avx2(src, src_stride, ref, ref_stride, uh, &vsse,        \
211                           &vsum16);                                           \
212       vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));               \
213       src += uh * src_stride;                                                 \
214       ref += uh * ref_stride;                                                 \
215     }                                                                         \
216     const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);                     \
217     const int sum = variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);  \
218     return *sse - (unsigned int)(((int64_t)sum * sum) >> bits);               \
219   }
220 
221 AOM_VAR_LOOP_AVX2(64, 64, 12, 32);    // 64x32 * ( 64/32)
222 AOM_VAR_LOOP_AVX2(64, 128, 13, 32);   // 64x32 * (128/32)
223 AOM_VAR_LOOP_AVX2(128, 64, 13, 16);   // 128x16 * ( 64/16)
224 AOM_VAR_LOOP_AVX2(128, 128, 14, 16);  // 128x16 * (128/16)
225 
// 16x16 "MSE": writes the sum of squared differences to *sse and returns it
// unnormalized. The variance computed along the way is deliberately dropped.
unsigned int aom_mse16x16_avx2(const uint8_t *src, int src_stride,
                               const uint8_t *ref, int ref_stride,
                               unsigned int *sse) {
  (void)aom_variance16x16_avx2(src, src_stride, ref, ref_stride, sse);
  return *sse;
}
232 
// Sub-pixel variance helper with no body in this file; declared here for the
// AOM_SUB_PIXEL_VAR_AVX2 wrappers below. From how those wrappers use it, it
// processes a 32-pixel-wide column of `height` rows, returns the signed sum
// of differences, and writes the SSE to *sse (NOTE(review): confirm against
// the definition).
unsigned int aom_sub_pixel_variance32xh_avx2(const uint8_t *src, int src_stride,
                                             int x_offset, int y_offset,
                                             const uint8_t *dst, int dst_stride,
                                             int height, unsigned int *sse);

// Averaging counterpart of the helper above, used by
// AOM_SUB_PIXEL_AVG_VAR_AVX2. `sec` / `sec_stride` describe the second
// predictor that the wrappers pass in; the SSE is written to *sseptr.
unsigned int aom_sub_pixel_avg_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, const uint8_t *sec, int sec_stride,
    int height, unsigned int *sseptr);
242 
// Defines aom_sub_pixel_variance<w>x<h>_avx2 by tiling the block into columns
// `wf` pixels wide and strips of at most 64 rows, running the <wf>xh helper
// on each tile and combining the partial sums and SSEs.
//   wlog2 / hlog2: log2 of w and h; divide se^2 by the pixel count.
#define AOM_SUB_PIXEL_VAR_AVX2(w, h, wf, wlog2, hlog2)                        \
  unsigned int aom_sub_pixel_variance##w##x##h##_avx2(                        \
      const uint8_t *src, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst, int dst_stride, unsigned int *sse_ptr) {            \
    /* Cap the tile height so the helper's accumulators cannot overflow. */   \
    const int tile_h = AOMMIN(h, 64);                                         \
    unsigned int sse_total = 0;                                               \
    int sum_total = 0;                                                        \
    for (int col = 0; col < (w / wf); ++col) {                                \
      for (int row = 0; row < (h / tile_h); ++row) {                          \
        const int y0 = row * tile_h;                                          \
        unsigned int tile_sse;                                                \
        sum_total += aom_sub_pixel_variance##wf##xh_avx2(                     \
            src + y0 * src_stride, src_stride, x_offset, y_offset,            \
            dst + y0 * dst_stride, dst_stride, tile_h, &tile_sse);            \
        sse_total += tile_sse;                                                \
      }                                                                       \
      src += wf;                                                              \
      dst += wf;                                                              \
    }                                                                         \
    const int64_t sum_sq = (int64_t)sum_total * sum_total;                    \
    *sse_ptr = sse_total;                                                     \
    return sse_total - (unsigned int)(sum_sq >> (wlog2 + hlog2));             \
  }

AOM_SUB_PIXEL_VAR_AVX2(128, 128, 32, 7, 7);
AOM_SUB_PIXEL_VAR_AVX2(128, 64, 32, 7, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 128, 32, 6, 7);
AOM_SUB_PIXEL_VAR_AVX2(64, 64, 32, 6, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 32, 32, 6, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 64, 32, 5, 6);
AOM_SUB_PIXEL_VAR_AVX2(32, 32, 32, 5, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 16, 32, 5, 4);
279 
// Defines aom_sub_pixel_avg_variance<w>x<h>_avx2 with the same tiling as the
// non-averaging variant. It also forwards the second predictor `sec`, which is
// laid out with stride w, to the <wf>xh averaging helper.
//   wlog2 / hlog2: log2 of w and h; divide se^2 by the pixel count.
#define AOM_SUB_PIXEL_AVG_VAR_AVX2(w, h, wf, wlog2, hlog2)                \
  unsigned int aom_sub_pixel_avg_variance##w##x##h##_avx2(                \
      const uint8_t *src, int src_stride, int x_offset, int y_offset,     \
      const uint8_t *dst, int dst_stride, unsigned int *sse_ptr,          \
      const uint8_t *sec) {                                               \
    /* Cap the tile height so the helper's accumulators cannot overflow. */ \
    const int tile_h = AOMMIN(h, 64);                                     \
    unsigned int sse_total = 0;                                           \
    int sum_total = 0;                                                    \
    for (int col = 0; col < (w / wf); ++col) {                            \
      for (int row = 0; row < (h / tile_h); ++row) {                      \
        const int y0 = row * tile_h;                                      \
        unsigned int tile_sse;                                            \
        sum_total += aom_sub_pixel_avg_variance##wf##xh_avx2(             \
            src + y0 * src_stride, src_stride, x_offset, y_offset,        \
            dst + y0 * dst_stride, dst_stride, sec + y0 * w, w, tile_h,   \
            &tile_sse);                                                   \
        sse_total += tile_sse;                                            \
      }                                                                   \
      src += wf;                                                          \
      dst += wf;                                                          \
      sec += wf;                                                          \
    }                                                                     \
    const int64_t sum_sq = (int64_t)sum_total * sum_total;                \
    *sse_ptr = sse_total;                                                 \
    return sse_total - (unsigned int)(sum_sq >> (wlog2 + hlog2));         \
  }

AOM_SUB_PIXEL_AVG_VAR_AVX2(128, 128, 32, 7, 7);
AOM_SUB_PIXEL_AVG_VAR_AVX2(128, 64, 32, 7, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 128, 32, 6, 7);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 64, 32, 6, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 32, 32, 6, 5);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 64, 32, 5, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 32, 32, 5, 5);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 16, 32, 5, 4);
320 
mm256_loadu2(const uint8_t * p0,const uint8_t * p1)321 static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
322   const __m256i d =
323       _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
324   return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
325 }
326 
mm256_loadu2_16(const uint16_t * p0,const uint16_t * p1)327 static INLINE __m256i mm256_loadu2_16(const uint16_t *p0, const uint16_t *p1) {
328   const __m256i d =
329       _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
330   return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
331 }
332 
comp_mask_pred_line_avx2(const __m256i s0,const __m256i s1,const __m256i a,uint8_t * comp_pred)333 static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
334                                             const __m256i a,
335                                             uint8_t *comp_pred) {
336   const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
337   const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
338   const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));
339 
340   const __m256i ma = _mm256_sub_epi8(alpha_max, a);
341 
342   const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);
343   const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
344   const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
345   const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);
346 
347   const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);
348   const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
349   const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);
350   const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);
351 
352   const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);
353   _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
354 }
355 
aom_comp_mask_pred_avx2(uint8_t * comp_pred,const uint8_t * pred,int width,int height,const uint8_t * ref,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)356 void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
357                              int height, const uint8_t *ref, int ref_stride,
358                              const uint8_t *mask, int mask_stride,
359                              int invert_mask) {
360   int i = 0;
361   const uint8_t *src0 = invert_mask ? pred : ref;
362   const uint8_t *src1 = invert_mask ? ref : pred;
363   const int stride0 = invert_mask ? width : ref_stride;
364   const int stride1 = invert_mask ? ref_stride : width;
365   if (width == 8) {
366     comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
367                            mask, mask_stride);
368   } else if (width == 16) {
369     do {
370       const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
371       const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
372       const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
373       src0 += (stride0 << 1);
374       src1 += (stride1 << 1);
375       mask += (mask_stride << 1);
376       const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
377       const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
378       const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
379       src0 += (stride0 << 1);
380       src1 += (stride1 << 1);
381       mask += (mask_stride << 1);
382       // comp_pred's stride == width == 16
383       comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
384       comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
385       comp_pred += (16 << 2);
386       i += 4;
387     } while (i < height);
388   } else {  // for width == 32
389     do {
390       const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
391       const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
392       const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
393 
394       const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
395       const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
396       const __m256i aB =
397           _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
398 
399       comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
400       comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
401       comp_pred += (32 << 1);
402 
403       src0 += (stride0 << 1);
404       src1 += (stride1 << 1);
405       mask += (mask_stride << 1);
406       i += 2;
407     } while (i < height);
408   }
409 }
410 
highbd_comp_mask_pred_line_avx2(const __m256i s0,const __m256i s1,const __m256i a)411 static INLINE __m256i highbd_comp_mask_pred_line_avx2(const __m256i s0,
412                                                       const __m256i s1,
413                                                       const __m256i a) {
414   const __m256i alpha_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
415   const __m256i round_const =
416       _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
417   const __m256i a_inv = _mm256_sub_epi16(alpha_max, a);
418 
419   const __m256i s_lo = _mm256_unpacklo_epi16(s0, s1);
420   const __m256i a_lo = _mm256_unpacklo_epi16(a, a_inv);
421   const __m256i pred_lo = _mm256_madd_epi16(s_lo, a_lo);
422   const __m256i pred_l = _mm256_srai_epi32(
423       _mm256_add_epi32(pred_lo, round_const), AOM_BLEND_A64_ROUND_BITS);
424 
425   const __m256i s_hi = _mm256_unpackhi_epi16(s0, s1);
426   const __m256i a_hi = _mm256_unpackhi_epi16(a, a_inv);
427   const __m256i pred_hi = _mm256_madd_epi16(s_hi, a_hi);
428   const __m256i pred_h = _mm256_srai_epi32(
429       _mm256_add_epi32(pred_hi, round_const), AOM_BLEND_A64_ROUND_BITS);
430 
431   const __m256i comp = _mm256_packs_epi32(pred_l, pred_h);
432 
433   return comp;
434 }
435 
aom_highbd_comp_mask_pred_avx2(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)436 void aom_highbd_comp_mask_pred_avx2(uint8_t *comp_pred8, const uint8_t *pred8,
437                                     int width, int height, const uint8_t *ref8,
438                                     int ref_stride, const uint8_t *mask,
439                                     int mask_stride, int invert_mask) {
440   int i = 0;
441   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
442   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
443   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
444   const uint16_t *src0 = invert_mask ? pred : ref;
445   const uint16_t *src1 = invert_mask ? ref : pred;
446   const int stride0 = invert_mask ? width : ref_stride;
447   const int stride1 = invert_mask ? ref_stride : width;
448   const __m256i zero = _mm256_setzero_si256();
449 
450   if (width == 8) {
451     do {
452       const __m256i s0 = mm256_loadu2_16(src0 + stride0, src0);
453       const __m256i s1 = mm256_loadu2_16(src1 + stride1, src1);
454 
455       const __m128i m_l = _mm_loadl_epi64((const __m128i *)mask);
456       const __m128i m_h = _mm_loadl_epi64((const __m128i *)(mask + 8));
457 
458       __m256i m = _mm256_castsi128_si256(m_l);
459       m = _mm256_insertf128_si256(m, m_h, 1);
460       const __m256i m_16 = _mm256_unpacklo_epi8(m, zero);
461 
462       const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
463 
464       _mm_storeu_si128((__m128i *)(comp_pred), _mm256_castsi256_si128(comp));
465 
466       _mm_storeu_si128((__m128i *)(comp_pred + width),
467                        _mm256_extractf128_si256(comp, 1));
468 
469       src0 += (stride0 << 1);
470       src1 += (stride1 << 1);
471       mask += (mask_stride << 1);
472       comp_pred += (width << 1);
473       i += 2;
474     } while (i < height);
475   } else if (width == 16) {
476     do {
477       const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0));
478       const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1));
479       const __m256i m_16 =
480           _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));
481 
482       const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
483 
484       _mm256_storeu_si256((__m256i *)comp_pred, comp);
485 
486       src0 += stride0;
487       src1 += stride1;
488       mask += mask_stride;
489       comp_pred += width;
490       i += 1;
491     } while (i < height);
492   } else if (width == 32) {
493     do {
494       const __m256i s0 = _mm256_loadu_si256((const __m256i *)src0);
495       const __m256i s2 = _mm256_loadu_si256((const __m256i *)(src0 + 16));
496       const __m256i s1 = _mm256_loadu_si256((const __m256i *)src1);
497       const __m256i s3 = _mm256_loadu_si256((const __m256i *)(src1 + 16));
498 
499       const __m256i m01_16 =
500           _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));
501       const __m256i m23_16 =
502           _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(mask + 16)));
503 
504       const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m01_16);
505       const __m256i comp1 = highbd_comp_mask_pred_line_avx2(s2, s3, m23_16);
506 
507       _mm256_storeu_si256((__m256i *)comp_pred, comp);
508       _mm256_storeu_si256((__m256i *)(comp_pred + 16), comp1);
509 
510       src0 += stride0;
511       src1 += stride1;
512       mask += mask_stride;
513       comp_pred += width;
514       i += 1;
515     } while (i < height);
516   }
517 }
518