1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <smmintrin.h>  // SSE4.1
13 #include <immintrin.h>  // AVX2
14 
15 #include <assert.h>
16 
17 #include "aom/aom_integer.h"
18 #include "aom_ports/mem.h"
19 #include "aom_dsp/aom_dsp_common.h"
20 
21 #include "aom_dsp/x86/synonyms.h"
22 #include "aom_dsp/x86/synonyms_avx2.h"
23 #include "aom_dsp/x86/blend_sse4.h"
24 #include "aom_dsp/x86/blend_mask_sse4.h"
25 
26 #include "config/aom_dsp_rtcd.h"
27 
blend_a64_d16_mask_w16_avx2(uint8_t * dst,const CONV_BUF_TYPE * src0,const CONV_BUF_TYPE * src1,const __m256i * m0,const __m256i * v_round_offset,const __m256i * v_maxval,int shift)28 static INLINE void blend_a64_d16_mask_w16_avx2(
29     uint8_t *dst, const CONV_BUF_TYPE *src0, const CONV_BUF_TYPE *src1,
30     const __m256i *m0, const __m256i *v_round_offset, const __m256i *v_maxval,
31     int shift) {
32   const __m256i max_minus_m0 = _mm256_sub_epi16(*v_maxval, *m0);
33   const __m256i s0_0 = yy_loadu_256(src0);
34   const __m256i s1_0 = yy_loadu_256(src1);
35   __m256i res0_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_0, s1_0),
36                                       _mm256_unpacklo_epi16(*m0, max_minus_m0));
37   __m256i res0_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_0, s1_0),
38                                       _mm256_unpackhi_epi16(*m0, max_minus_m0));
39   res0_lo =
40       _mm256_srai_epi32(_mm256_sub_epi32(res0_lo, *v_round_offset), shift);
41   res0_hi =
42       _mm256_srai_epi32(_mm256_sub_epi32(res0_hi, *v_round_offset), shift);
43   const __m256i res0 = _mm256_packs_epi32(res0_lo, res0_hi);
44   __m256i res = _mm256_packus_epi16(res0, res0);
45   res = _mm256_permute4x64_epi64(res, 0xd8);
46   _mm_storeu_si128((__m128i *)(dst), _mm256_castsi256_si128(res));
47 }
48 
blend_a64_d16_mask_w32_avx2(uint8_t * dst,const CONV_BUF_TYPE * src0,const CONV_BUF_TYPE * src1,const __m256i * m0,const __m256i * m1,const __m256i * v_round_offset,const __m256i * v_maxval,int shift)49 static INLINE void blend_a64_d16_mask_w32_avx2(
50     uint8_t *dst, const CONV_BUF_TYPE *src0, const CONV_BUF_TYPE *src1,
51     const __m256i *m0, const __m256i *m1, const __m256i *v_round_offset,
52     const __m256i *v_maxval, int shift) {
53   const __m256i max_minus_m0 = _mm256_sub_epi16(*v_maxval, *m0);
54   const __m256i max_minus_m1 = _mm256_sub_epi16(*v_maxval, *m1);
55   const __m256i s0_0 = yy_loadu_256(src0);
56   const __m256i s0_1 = yy_loadu_256(src0 + 16);
57   const __m256i s1_0 = yy_loadu_256(src1);
58   const __m256i s1_1 = yy_loadu_256(src1 + 16);
59   __m256i res0_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_0, s1_0),
60                                       _mm256_unpacklo_epi16(*m0, max_minus_m0));
61   __m256i res0_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_0, s1_0),
62                                       _mm256_unpackhi_epi16(*m0, max_minus_m0));
63   __m256i res1_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_1, s1_1),
64                                       _mm256_unpacklo_epi16(*m1, max_minus_m1));
65   __m256i res1_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_1, s1_1),
66                                       _mm256_unpackhi_epi16(*m1, max_minus_m1));
67   res0_lo =
68       _mm256_srai_epi32(_mm256_sub_epi32(res0_lo, *v_round_offset), shift);
69   res0_hi =
70       _mm256_srai_epi32(_mm256_sub_epi32(res0_hi, *v_round_offset), shift);
71   res1_lo =
72       _mm256_srai_epi32(_mm256_sub_epi32(res1_lo, *v_round_offset), shift);
73   res1_hi =
74       _mm256_srai_epi32(_mm256_sub_epi32(res1_hi, *v_round_offset), shift);
75   const __m256i res0 = _mm256_packs_epi32(res0_lo, res0_hi);
76   const __m256i res1 = _mm256_packs_epi32(res1_lo, res1_hi);
77   __m256i res = _mm256_packus_epi16(res0, res1);
78   res = _mm256_permute4x64_epi64(res, 0xd8);
79   _mm256_storeu_si256((__m256i *)(dst), res);
80 }
81 
lowbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,const __m256i * round_offset,int shift)82 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
83     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
84     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
85     const uint8_t *mask, uint32_t mask_stride, int h,
86     const __m256i *round_offset, int shift) {
87   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
88   for (int i = 0; i < h; ++i) {
89     const __m128i m = xx_loadu_128(mask);
90     const __m256i m0 = _mm256_cvtepu8_epi16(m);
91 
92     blend_a64_d16_mask_w16_avx2(dst, src0, src1, &m0, round_offset, &v_maxval,
93                                 shift);
94     mask += mask_stride;
95     dst += dst_stride;
96     src0 += src0_stride;
97     src1 += src1_stride;
98   }
99 }
100 
lowbd_blend_a64_d16_mask_subw0_subh0_w32_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)101 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh0_w32_avx2(
102     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
103     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
104     const uint8_t *mask, uint32_t mask_stride, int h, int w,
105     const __m256i *round_offset, int shift) {
106   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
107   for (int i = 0; i < h; ++i) {
108     for (int j = 0; j < w; j += 32) {
109       const __m256i m = yy_loadu_256(mask + j);
110       const __m256i m0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(m));
111       const __m256i m1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(m, 1));
112 
113       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
114                                   round_offset, &v_maxval, shift);
115     }
116     mask += mask_stride;
117     dst += dst_stride;
118     src0 += src0_stride;
119     src1 += src1_stride;
120   }
121 }
122 
lowbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,const __m256i * round_offset,int shift)123 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
124     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
125     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
126     const uint8_t *mask, uint32_t mask_stride, int h,
127     const __m256i *round_offset, int shift) {
128   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
129   const __m256i one_b = _mm256_set1_epi8(1);
130   const __m256i two_w = _mm256_set1_epi16(2);
131   for (int i = 0; i < h; ++i) {
132     const __m256i m_i00 = yy_loadu_256(mask);
133     const __m256i m_i10 = yy_loadu_256(mask + mask_stride);
134 
135     const __m256i m0_ac = _mm256_adds_epu8(m_i00, m_i10);
136     const __m256i m0_acbd = _mm256_maddubs_epi16(m0_ac, one_b);
137     const __m256i m0 = _mm256_srli_epi16(_mm256_add_epi16(m0_acbd, two_w), 2);
138 
139     blend_a64_d16_mask_w16_avx2(dst, src0, src1, &m0, round_offset, &v_maxval,
140                                 shift);
141     mask += mask_stride << 1;
142     dst += dst_stride;
143     src0 += src0_stride;
144     src1 += src1_stride;
145   }
146 }
147 
lowbd_blend_a64_d16_mask_subw1_subh1_w32_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)148 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh1_w32_avx2(
149     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
150     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
151     const uint8_t *mask, uint32_t mask_stride, int h, int w,
152     const __m256i *round_offset, int shift) {
153   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
154   const __m256i one_b = _mm256_set1_epi8(1);
155   const __m256i two_w = _mm256_set1_epi16(2);
156   for (int i = 0; i < h; ++i) {
157     for (int j = 0; j < w; j += 32) {
158       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
159       const __m256i m_i01 = yy_loadu_256(mask + 2 * j + 32);
160       const __m256i m_i10 = yy_loadu_256(mask + mask_stride + 2 * j);
161       const __m256i m_i11 = yy_loadu_256(mask + mask_stride + 2 * j + 32);
162 
163       const __m256i m0_ac = _mm256_adds_epu8(m_i00, m_i10);
164       const __m256i m1_ac = _mm256_adds_epu8(m_i01, m_i11);
165       const __m256i m0_acbd = _mm256_maddubs_epi16(m0_ac, one_b);
166       const __m256i m1_acbd = _mm256_maddubs_epi16(m1_ac, one_b);
167       const __m256i m0 = _mm256_srli_epi16(_mm256_add_epi16(m0_acbd, two_w), 2);
168       const __m256i m1 = _mm256_srli_epi16(_mm256_add_epi16(m1_acbd, two_w), 2);
169 
170       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
171                                   round_offset, &v_maxval, shift);
172     }
173     mask += mask_stride << 1;
174     dst += dst_stride;
175     src0 += src0_stride;
176     src1 += src1_stride;
177   }
178 }
179 
lowbd_blend_a64_d16_mask_subw1_subh0_w16_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)180 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh0_w16_avx2(
181     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
182     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
183     const uint8_t *mask, uint32_t mask_stride, int h, int w,
184     const __m256i *round_offset, int shift) {
185   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
186   const __m256i one_b = _mm256_set1_epi8(1);
187   const __m256i zeros = _mm256_setzero_si256();
188   for (int i = 0; i < h; ++i) {
189     for (int j = 0; j < w; j += 16) {
190       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
191       const __m256i m0_ac = _mm256_maddubs_epi16(m_i00, one_b);
192       const __m256i m0 = _mm256_avg_epu16(m0_ac, zeros);
193 
194       blend_a64_d16_mask_w16_avx2(dst + j, src0 + j, src1 + j, &m0,
195                                   round_offset, &v_maxval, shift);
196     }
197     mask += mask_stride;
198     dst += dst_stride;
199     src0 += src0_stride;
200     src1 += src1_stride;
201   }
202 }
203 
lowbd_blend_a64_d16_mask_subw1_subh0_w32_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)204 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh0_w32_avx2(
205     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
206     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
207     const uint8_t *mask, uint32_t mask_stride, int h, int w,
208     const __m256i *round_offset, int shift) {
209   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
210   const __m256i one_b = _mm256_set1_epi8(1);
211   const __m256i zeros = _mm256_setzero_si256();
212   for (int i = 0; i < h; ++i) {
213     for (int j = 0; j < w; j += 32) {
214       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
215       const __m256i m_i01 = yy_loadu_256(mask + 2 * j + 32);
216       const __m256i m0_ac = _mm256_maddubs_epi16(m_i00, one_b);
217       const __m256i m1_ac = _mm256_maddubs_epi16(m_i01, one_b);
218       const __m256i m0 = _mm256_avg_epu16(m0_ac, zeros);
219       const __m256i m1 = _mm256_avg_epu16(m1_ac, zeros);
220 
221       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
222                                   round_offset, &v_maxval, shift);
223     }
224     mask += mask_stride;
225     dst += dst_stride;
226     src0 += src0_stride;
227     src1 += src1_stride;
228   }
229 }
230 
lowbd_blend_a64_d16_mask_subw0_subh1_w16_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)231 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh1_w16_avx2(
232     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
233     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
234     const uint8_t *mask, uint32_t mask_stride, int h, int w,
235     const __m256i *round_offset, int shift) {
236   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
237   const __m128i zeros = _mm_setzero_si128();
238   for (int i = 0; i < h; ++i) {
239     for (int j = 0; j < w; j += 16) {
240       const __m128i m_i00 = xx_loadu_128(mask + j);
241       const __m128i m_i10 = xx_loadu_128(mask + mask_stride + j);
242 
243       const __m128i m_ac = _mm_avg_epu8(_mm_adds_epu8(m_i00, m_i10), zeros);
244       const __m256i m0 = _mm256_cvtepu8_epi16(m_ac);
245 
246       blend_a64_d16_mask_w16_avx2(dst + j, src0 + j, src1 + j, &m0,
247                                   round_offset, &v_maxval, shift);
248     }
249     mask += mask_stride << 1;
250     dst += dst_stride;
251     src0 += src0_stride;
252     src1 += src1_stride;
253   }
254 }
255 
lowbd_blend_a64_d16_mask_subw0_subh1_w32_avx2(uint8_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,int w,const __m256i * round_offset,int shift)256 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh1_w32_avx2(
257     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
258     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
259     const uint8_t *mask, uint32_t mask_stride, int h, int w,
260     const __m256i *round_offset, int shift) {
261   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
262   const __m256i zeros = _mm256_setzero_si256();
263   for (int i = 0; i < h; ++i) {
264     for (int j = 0; j < w; j += 32) {
265       const __m256i m_i00 = yy_loadu_256(mask + j);
266       const __m256i m_i10 = yy_loadu_256(mask + mask_stride + j);
267 
268       const __m256i m_ac =
269           _mm256_avg_epu8(_mm256_adds_epu8(m_i00, m_i10), zeros);
270       const __m256i m0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(m_ac));
271       const __m256i m1 =
272           _mm256_cvtepu8_epi16(_mm256_extracti128_si256(m_ac, 1));
273 
274       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
275                                   round_offset, &v_maxval, shift);
276     }
277     mask += mask_stride << 1;
278     dst += dst_stride;
279     src0 += src0_stride;
280     src1 += src1_stride;
281   }
282 }
283 
// Blends two intermediate (CONV_BUF_TYPE) predictions into an 8-bit
// destination under a per-pixel mask, dispatching on block width and on the
// mask subsampling flags.
//
//   dst, dst_stride:    8-bit output and its row stride.
//   src0/src1 + stride: the two intermediate predictions and their strides.
//   mask, mask_stride:  8-bit mask and its row stride.
//   w, h:               block size; asserted to be powers of two and >= 4.
//   subw, subh:         1 when the mask is subsampled 2x horizontally /
//                       vertically relative to the block, else 0.
//   conv_params:        supplies round_0 and round_1, from which the rounding
//                       offset and the final shift are derived.
//
// Widths 4 and 8 are forwarded to the SSE4.1 kernels (which take the 128-bit
// round offset); width 16 and the default case use the AVX2 kernels in this
// file (which take the 256-bit round offset).
void aom_lowbd_blend_a64_d16_mask_avx2(
    uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params) {
  // This entry point is the low-bitdepth variant: bit depth is fixed at 8.
  const int bd = 8;
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;

  // Constant subtracted from every weighted sum before the final shift. It is
  // pre-scaled by AOM_BLEND_A64_ROUND_BITS so it can be applied after the
  // mask multiplication.
  // NOTE(review): presumably this cancels the offset carried by the
  // intermediate buffers and folds in the rounding term; confirm against the
  // C reference implementation.
  const int round_offset =
      ((1 << (round_bits + bd)) + (1 << (round_bits + bd - 1)) -
       (1 << (round_bits - 1)))
      << AOM_BLEND_A64_ROUND_BITS;

  // A single arithmetic shift undoes both the convolve rounding and the mask
  // scaling.
  const int shift = round_bits + AOM_BLEND_A64_ROUND_BITS;
  // In-place operation is only permitted when the strides match.
  assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));

  assert(h >= 4);
  assert(w >= 4);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));
  // Same offset broadcast at both vector widths: 128-bit for the SSE4.1
  // kernels, 256-bit for the AVX2 kernels.
  const __m128i v_round_offset = _mm_set1_epi32(round_offset);
  const __m256i y_round_offset = _mm256_set1_epi32(round_offset);

  if (subw == 0 && subh == 0) {
    // Mask at full resolution.
    switch (w) {
      case 4:
        aom_lowbd_blend_a64_d16_mask_subw0_subh0_w4_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 8:
        aom_lowbd_blend_a64_d16_mask_subw0_subh0_w8_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 16:
        lowbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &y_round_offset, shift);
        break;
      default:
        // Every remaining width goes to the kernel that steps 32 columns at
        // a time.
        lowbd_blend_a64_d16_mask_subw0_subh0_w32_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
    }
  } else if (subw == 1 && subh == 1) {
    // Mask subsampled 2x in both directions.
    switch (w) {
      case 4:
        aom_lowbd_blend_a64_d16_mask_subw1_subh1_w4_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 8:
        aom_lowbd_blend_a64_d16_mask_subw1_subh1_w8_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 16:
        lowbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &y_round_offset, shift);
        break;
      default:
        lowbd_blend_a64_d16_mask_subw1_subh1_w32_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
    }
  } else if (subw == 1 && subh == 0) {
    // Mask subsampled 2x horizontally only.
    switch (w) {
      case 4:
        aom_lowbd_blend_a64_d16_mask_subw1_subh0_w4_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 8:
        aom_lowbd_blend_a64_d16_mask_subw1_subh0_w8_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 16:
        lowbd_blend_a64_d16_mask_subw1_subh0_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
      default:
        lowbd_blend_a64_d16_mask_subw1_subh0_w32_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
    }
  } else {
    // Any other flag combination lands here; the kernels used are the
    // vertical-only (subw == 0, subh == 1) variants.
    switch (w) {
      case 4:
        aom_lowbd_blend_a64_d16_mask_subw0_subh1_w4_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 8:
        aom_lowbd_blend_a64_d16_mask_subw0_subh1_w8_sse4_1(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift);
        break;
      case 16:
        lowbd_blend_a64_d16_mask_subw0_subh1_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
      default:
        lowbd_blend_a64_d16_mask_subw0_subh1_w32_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &y_round_offset, shift);
        break;
    }
  }
}
403 
blend_16_u8_avx2(const uint8_t * src0,const uint8_t * src1,const __m256i * v_m0_b,const __m256i * v_m1_b,const int32_t bits)404 static INLINE __m256i blend_16_u8_avx2(const uint8_t *src0, const uint8_t *src1,
405                                        const __m256i *v_m0_b,
406                                        const __m256i *v_m1_b,
407                                        const int32_t bits) {
408   const __m256i v_s0_b = _mm256_castsi128_si256(xx_loadu_128(src0));
409   const __m256i v_s1_b = _mm256_castsi128_si256(xx_loadu_128(src1));
410   const __m256i v_s0_s_b = _mm256_permute4x64_epi64(v_s0_b, 0xd8);
411   const __m256i v_s1_s_b = _mm256_permute4x64_epi64(v_s1_b, 0xd8);
412 
413   const __m256i v_p0_w =
414       _mm256_maddubs_epi16(_mm256_unpacklo_epi8(v_s0_s_b, v_s1_s_b),
415                            _mm256_unpacklo_epi8(*v_m0_b, *v_m1_b));
416 
417   const __m256i v_res0_w = yy_roundn_epu16(v_p0_w, bits);
418   const __m256i v_res_b = _mm256_packus_epi16(v_res0_w, v_res0_w);
419   const __m256i v_res = _mm256_permute4x64_epi64(v_res_b, 0xd8);
420   return v_res;
421 }
422 
blend_32_u8_avx2(const uint8_t * src0,const uint8_t * src1,const __m256i * v_m0_b,const __m256i * v_m1_b,const int32_t bits)423 static INLINE __m256i blend_32_u8_avx2(const uint8_t *src0, const uint8_t *src1,
424                                        const __m256i *v_m0_b,
425                                        const __m256i *v_m1_b,
426                                        const int32_t bits) {
427   const __m256i v_s0_b = yy_loadu_256(src0);
428   const __m256i v_s1_b = yy_loadu_256(src1);
429 
430   const __m256i v_p0_w =
431       _mm256_maddubs_epi16(_mm256_unpacklo_epi8(v_s0_b, v_s1_b),
432                            _mm256_unpacklo_epi8(*v_m0_b, *v_m1_b));
433   const __m256i v_p1_w =
434       _mm256_maddubs_epi16(_mm256_unpackhi_epi8(v_s0_b, v_s1_b),
435                            _mm256_unpackhi_epi8(*v_m0_b, *v_m1_b));
436 
437   const __m256i v_res0_w = yy_roundn_epu16(v_p0_w, bits);
438   const __m256i v_res1_w = yy_roundn_epu16(v_p1_w, bits);
439   const __m256i v_res = _mm256_packus_epi16(v_res0_w, v_res1_w);
440   return v_res;
441 }
442 
blend_a64_mask_sx_sy_w16_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h)443 static INLINE void blend_a64_mask_sx_sy_w16_avx2(
444     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
445     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
446     const uint8_t *mask, uint32_t mask_stride, int h) {
447   const __m256i v_zmask_b = _mm256_set1_epi16(0xFF);
448   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
449   do {
450     const __m256i v_ral_b = yy_loadu_256(mask);
451     const __m256i v_rbl_b = yy_loadu_256(mask + mask_stride);
452     const __m256i v_rvsl_b = _mm256_add_epi8(v_ral_b, v_rbl_b);
453     const __m256i v_rvsal_w = _mm256_and_si256(v_rvsl_b, v_zmask_b);
454     const __m256i v_rvsbl_w =
455         _mm256_and_si256(_mm256_srli_si256(v_rvsl_b, 1), v_zmask_b);
456     const __m256i v_rsl_w = _mm256_add_epi16(v_rvsal_w, v_rvsbl_w);
457 
458     const __m256i v_m0_w = yy_roundn_epu16(v_rsl_w, 2);
459     const __m256i v_m0_b = _mm256_packus_epi16(v_m0_w, v_m0_w);
460     const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
461 
462     const __m256i y_res_b = blend_16_u8_avx2(src0, src1, &v_m0_b, &v_m1_b,
463                                              AOM_BLEND_A64_ROUND_BITS);
464 
465     xx_storeu_128(dst, _mm256_castsi256_si128(y_res_b));
466     dst += dst_stride;
467     src0 += src0_stride;
468     src1 += src1_stride;
469     mask += 2 * mask_stride;
470   } while (--h);
471 }
472 
blend_a64_mask_sx_sy_w32n_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)473 static INLINE void blend_a64_mask_sx_sy_w32n_avx2(
474     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
475     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
476     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
477   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
478   const __m256i v_zmask_b = _mm256_set1_epi16(0xFF);
479   do {
480     int c;
481     for (c = 0; c < w; c += 32) {
482       const __m256i v_ral_b = yy_loadu_256(mask + 2 * c);
483       const __m256i v_rah_b = yy_loadu_256(mask + 2 * c + 32);
484       const __m256i v_rbl_b = yy_loadu_256(mask + mask_stride + 2 * c);
485       const __m256i v_rbh_b = yy_loadu_256(mask + mask_stride + 2 * c + 32);
486       const __m256i v_rvsl_b = _mm256_add_epi8(v_ral_b, v_rbl_b);
487       const __m256i v_rvsh_b = _mm256_add_epi8(v_rah_b, v_rbh_b);
488       const __m256i v_rvsal_w = _mm256_and_si256(v_rvsl_b, v_zmask_b);
489       const __m256i v_rvsah_w = _mm256_and_si256(v_rvsh_b, v_zmask_b);
490       const __m256i v_rvsbl_w =
491           _mm256_and_si256(_mm256_srli_si256(v_rvsl_b, 1), v_zmask_b);
492       const __m256i v_rvsbh_w =
493           _mm256_and_si256(_mm256_srli_si256(v_rvsh_b, 1), v_zmask_b);
494       const __m256i v_rsl_w = _mm256_add_epi16(v_rvsal_w, v_rvsbl_w);
495       const __m256i v_rsh_w = _mm256_add_epi16(v_rvsah_w, v_rvsbh_w);
496 
497       const __m256i v_m0l_w = yy_roundn_epu16(v_rsl_w, 2);
498       const __m256i v_m0h_w = yy_roundn_epu16(v_rsh_w, 2);
499       const __m256i v_m0_b =
500           _mm256_permute4x64_epi64(_mm256_packus_epi16(v_m0l_w, v_m0h_w), 0xd8);
501       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
502 
503       const __m256i v_res_b = blend_32_u8_avx2(
504           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
505 
506       yy_storeu_256(dst + c, v_res_b);
507     }
508     dst += dst_stride;
509     src0 += src0_stride;
510     src1 += src1_stride;
511     mask += 2 * mask_stride;
512   } while (--h);
513 }
514 
blend_a64_mask_sx_sy_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)515 static INLINE void blend_a64_mask_sx_sy_avx2(
516     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
517     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
518     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
519   const __m128i v_shuffle_b = xx_loadu_128(g_blend_a64_mask_shuffle);
520   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
521   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
522   switch (w) {
523     case 4:
524       do {
525         const __m128i v_ra_b = xx_loadl_64(mask);
526         const __m128i v_rb_b = xx_loadl_64(mask + mask_stride);
527         const __m128i v_rvs_b = _mm_add_epi8(v_ra_b, v_rb_b);
528         const __m128i v_r_s_b = _mm_shuffle_epi8(v_rvs_b, v_shuffle_b);
529         const __m128i v_r0_s_w = _mm_cvtepu8_epi16(v_r_s_b);
530         const __m128i v_r1_s_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_r_s_b, 8));
531         const __m128i v_rs_w = _mm_add_epi16(v_r0_s_w, v_r1_s_w);
532         const __m128i v_m0_w = xx_roundn_epu16(v_rs_w, 2);
533         const __m128i v_m0_b = _mm_packus_epi16(v_m0_w, v_m0_w);
534         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
535 
536         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
537 
538         xx_storel_32(dst, v_res_b);
539 
540         dst += dst_stride;
541         src0 += src0_stride;
542         src1 += src1_stride;
543         mask += 2 * mask_stride;
544       } while (--h);
545       break;
546     case 8:
547       do {
548         const __m128i v_ra_b = xx_loadu_128(mask);
549         const __m128i v_rb_b = xx_loadu_128(mask + mask_stride);
550         const __m128i v_rvs_b = _mm_add_epi8(v_ra_b, v_rb_b);
551         const __m128i v_r_s_b = _mm_shuffle_epi8(v_rvs_b, v_shuffle_b);
552         const __m128i v_r0_s_w = _mm_cvtepu8_epi16(v_r_s_b);
553         const __m128i v_r1_s_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_r_s_b, 8));
554         const __m128i v_rs_w = _mm_add_epi16(v_r0_s_w, v_r1_s_w);
555         const __m128i v_m0_w = xx_roundn_epu16(v_rs_w, 2);
556         const __m128i v_m0_b = _mm_packus_epi16(v_m0_w, v_m0_w);
557         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
558 
559         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
560 
561         xx_storel_64(dst, v_res_b);
562 
563         dst += dst_stride;
564         src0 += src0_stride;
565         src1 += src1_stride;
566         mask += 2 * mask_stride;
567       } while (--h);
568       break;
569     case 16:
570       blend_a64_mask_sx_sy_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
571                                     src1_stride, mask, mask_stride, h);
572       break;
573     default:
574       blend_a64_mask_sx_sy_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
575                                      src1_stride, mask, mask_stride, w, h);
576       break;
577   }
578 }
579 
blend_a64_mask_sx_w16_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h)580 static INLINE void blend_a64_mask_sx_w16_avx2(
581     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
582     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
583     const uint8_t *mask, uint32_t mask_stride, int h) {
584   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
585   const __m256i v_zmask_b = _mm256_set1_epi16(0xff);
586   do {
587     const __m256i v_rl_b = yy_loadu_256(mask);
588     const __m256i v_al_b =
589         _mm256_avg_epu8(v_rl_b, _mm256_srli_si256(v_rl_b, 1));
590 
591     const __m256i v_m0_w = _mm256_and_si256(v_al_b, v_zmask_b);
592     const __m256i v_m0_b = _mm256_packus_epi16(v_m0_w, _mm256_setzero_si256());
593     const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
594 
595     const __m256i v_res_b = blend_16_u8_avx2(src0, src1, &v_m0_b, &v_m1_b,
596                                              AOM_BLEND_A64_ROUND_BITS);
597 
598     xx_storeu_128(dst, _mm256_castsi256_si128(v_res_b));
599     dst += dst_stride;
600     src0 += src0_stride;
601     src1 += src1_stride;
602     mask += mask_stride;
603   } while (--h);
604 }
605 
blend_a64_mask_sx_w32n_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)606 static INLINE void blend_a64_mask_sx_w32n_avx2(
607     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
608     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
609     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
610   const __m256i v_shuffle_b = yy_loadu_256(g_blend_a64_mask_shuffle);
611   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
612   do {
613     int c;
614     for (c = 0; c < w; c += 32) {
615       const __m256i v_r0_b = yy_loadu_256(mask + 2 * c);
616       const __m256i v_r1_b = yy_loadu_256(mask + 2 * c + 32);
617       const __m256i v_r0_s_b = _mm256_shuffle_epi8(v_r0_b, v_shuffle_b);
618       const __m256i v_r1_s_b = _mm256_shuffle_epi8(v_r1_b, v_shuffle_b);
619       const __m256i v_al_b =
620           _mm256_avg_epu8(v_r0_s_b, _mm256_srli_si256(v_r0_s_b, 8));
621       const __m256i v_ah_b =
622           _mm256_avg_epu8(v_r1_s_b, _mm256_srli_si256(v_r1_s_b, 8));
623 
624       const __m256i v_m0_b =
625           _mm256_permute4x64_epi64(_mm256_unpacklo_epi64(v_al_b, v_ah_b), 0xd8);
626       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
627 
628       const __m256i v_res_b = blend_32_u8_avx2(
629           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
630 
631       yy_storeu_256(dst + c, v_res_b);
632     }
633     dst += dst_stride;
634     src0 += src0_stride;
635     src1 += src1_stride;
636     mask += mask_stride;
637   } while (--h);
638 }
639 
blend_a64_mask_sx_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)640 static INLINE void blend_a64_mask_sx_avx2(
641     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
642     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
643     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
644   const __m128i v_shuffle_b = xx_loadu_128(g_blend_a64_mask_shuffle);
645   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
646   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
647   switch (w) {
648     case 4:
649       do {
650         const __m128i v_r_b = xx_loadl_64(mask);
651         const __m128i v_r0_s_b = _mm_shuffle_epi8(v_r_b, v_shuffle_b);
652         const __m128i v_r_lo_b = _mm_unpacklo_epi64(v_r0_s_b, v_r0_s_b);
653         const __m128i v_r_hi_b = _mm_unpackhi_epi64(v_r0_s_b, v_r0_s_b);
654         const __m128i v_m0_b = _mm_avg_epu8(v_r_lo_b, v_r_hi_b);
655         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
656 
657         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
658 
659         xx_storel_32(dst, v_res_b);
660 
661         dst += dst_stride;
662         src0 += src0_stride;
663         src1 += src1_stride;
664         mask += mask_stride;
665       } while (--h);
666       break;
667     case 8:
668       do {
669         const __m128i v_r_b = xx_loadu_128(mask);
670         const __m128i v_r0_s_b = _mm_shuffle_epi8(v_r_b, v_shuffle_b);
671         const __m128i v_r_lo_b = _mm_unpacklo_epi64(v_r0_s_b, v_r0_s_b);
672         const __m128i v_r_hi_b = _mm_unpackhi_epi64(v_r0_s_b, v_r0_s_b);
673         const __m128i v_m0_b = _mm_avg_epu8(v_r_lo_b, v_r_hi_b);
674         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
675 
676         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
677 
678         xx_storel_64(dst, v_res_b);
679 
680         dst += dst_stride;
681         src0 += src0_stride;
682         src1 += src1_stride;
683         mask += mask_stride;
684       } while (--h);
685       break;
686     case 16:
687       blend_a64_mask_sx_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
688                                  src1_stride, mask, mask_stride, h);
689       break;
690     default:
691       blend_a64_mask_sx_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
692                                   src1_stride, mask, mask_stride, w, h);
693       break;
694   }
695 }
696 
blend_a64_mask_sy_w16_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h)697 static INLINE void blend_a64_mask_sy_w16_avx2(
698     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
699     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
700     const uint8_t *mask, uint32_t mask_stride, int h) {
701   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
702   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
703   do {
704     const __m128i v_ra_b = xx_loadu_128(mask);
705     const __m128i v_rb_b = xx_loadu_128(mask + mask_stride);
706     const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
707 
708     const __m128i v_m1_b = _mm_sub_epi16(v_maxval_b, v_m0_b);
709     const __m128i v_res_b = blend_16_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
710 
711     xx_storeu_128(dst, v_res_b);
712     dst += dst_stride;
713     src0 += src0_stride;
714     src1 += src1_stride;
715     mask += 2 * mask_stride;
716   } while (--h);
717 }
718 
blend_a64_mask_sy_w32n_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)719 static INLINE void blend_a64_mask_sy_w32n_avx2(
720     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
721     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
722     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
723   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
724   do {
725     int c;
726     for (c = 0; c < w; c += 32) {
727       const __m256i v_ra_b = yy_loadu_256(mask + c);
728       const __m256i v_rb_b = yy_loadu_256(mask + c + mask_stride);
729       const __m256i v_m0_b = _mm256_avg_epu8(v_ra_b, v_rb_b);
730       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
731       const __m256i v_res_b = blend_32_u8_avx2(
732           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
733 
734       yy_storeu_256(dst + c, v_res_b);
735     }
736     dst += dst_stride;
737     src0 += src0_stride;
738     src1 += src1_stride;
739     mask += 2 * mask_stride;
740   } while (--h);
741 }
742 
blend_a64_mask_sy_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)743 static INLINE void blend_a64_mask_sy_avx2(
744     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
745     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
746     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
747   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
748   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
749   switch (w) {
750     case 4:
751       do {
752         const __m128i v_ra_b = xx_loadl_32(mask);
753         const __m128i v_rb_b = xx_loadl_32(mask + mask_stride);
754         const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
755         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
756         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
757 
758         xx_storel_32(dst, v_res_b);
759 
760         dst += dst_stride;
761         src0 += src0_stride;
762         src1 += src1_stride;
763         mask += 2 * mask_stride;
764       } while (--h);
765       break;
766     case 8:
767       do {
768         const __m128i v_ra_b = xx_loadl_64(mask);
769         const __m128i v_rb_b = xx_loadl_64(mask + mask_stride);
770         const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
771         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
772         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
773 
774         xx_storel_64(dst, v_res_b);
775 
776         dst += dst_stride;
777         src0 += src0_stride;
778         src1 += src1_stride;
779         mask += 2 * mask_stride;
780       } while (--h);
781       break;
782     case 16:
783       blend_a64_mask_sy_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
784                                  src1_stride, mask, mask_stride, h);
785       break;
786     default:
787       blend_a64_mask_sy_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
788                                   src1_stride, mask, mask_stride, w, h);
789   }
790 }
791 
blend_a64_mask_w32n_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)792 static INLINE void blend_a64_mask_w32n_avx2(
793     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
794     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
795     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
796   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
797   do {
798     int c;
799     for (c = 0; c < w; c += 32) {
800       const __m256i v_m0_b = yy_loadu_256(mask + c);
801       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
802 
803       const __m256i v_res_b = blend_32_u8_avx2(
804           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
805 
806       yy_storeu_256(dst + c, v_res_b);
807     }
808     dst += dst_stride;
809     src0 += src0_stride;
810     src1 += src1_stride;
811     mask += mask_stride;
812   } while (--h);
813 }
814 
blend_a64_mask_avx2(uint8_t * dst,uint32_t dst_stride,const uint8_t * src0,uint32_t src0_stride,const uint8_t * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int w,int h)815 static INLINE void blend_a64_mask_avx2(
816     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
817     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
818     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
819   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
820   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
821   switch (w) {
822     case 4:
823       do {
824         const __m128i v_m0_b = xx_loadl_32(mask);
825         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
826         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
827 
828         xx_storel_32(dst, v_res_b);
829 
830         dst += dst_stride;
831         src0 += src0_stride;
832         src1 += src1_stride;
833         mask += mask_stride;
834       } while (--h);
835       break;
836     case 8:
837       do {
838         const __m128i v_m0_b = xx_loadl_64(mask);
839         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
840         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
841 
842         xx_storel_64(dst, v_res_b);
843 
844         dst += dst_stride;
845         src0 += src0_stride;
846         src1 += src1_stride;
847         mask += mask_stride;
848       } while (--h);
849       break;
850     case 16:
851       do {
852         const __m128i v_m0_b = xx_loadu_128(mask);
853         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
854         const __m128i v_res_b = blend_16_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
855 
856         xx_storeu_128(dst, v_res_b);
857         dst += dst_stride;
858         src0 += src0_stride;
859         src1 += src1_stride;
860         mask += mask_stride;
861       } while (--h);
862       break;
863     default:
864       blend_a64_mask_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
865                                src1_stride, mask, mask_stride, w, h);
866   }
867 }
868 
// Public entry point: blends src0 and src1 into dst under an alpha mask.
// subx / suby select whether the mask is subsampled horizontally and/or
// vertically relative to the output.  Blocks whose width or height is not a
// multiple of 4 are delegated to the C implementation; everything else goes
// to the matching AVX2 kernel.
void aom_blend_a64_mask_avx2(uint8_t *dst, uint32_t dst_stride,
                             const uint8_t *src0, uint32_t src0_stride,
                             const uint8_t *src1, uint32_t src1_stride,
                             const uint8_t *mask, uint32_t mask_stride, int w,
                             int h, int subx, int suby) {
  // In-place blending is only valid when the strides agree.
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  // With power-of-two dimensions, a set low bit in (h | w) means one of them
  // is 1 or 2; the SIMD kernels need at least 4 in each direction.
  if (UNLIKELY((h | w) & 3)) {
    aom_blend_a64_mask_c(dst, dst_stride, src0, src0_stride, src1, src1_stride,
                         mask, mask_stride, w, h, subx, suby);
    return;
  }

  if (subx & suby) {
    blend_a64_mask_sx_sy_avx2(dst, dst_stride, src0, src0_stride, src1,
                              src1_stride, mask, mask_stride, w, h);
  } else if (subx) {
    blend_a64_mask_sx_avx2(dst, dst_stride, src0, src0_stride, src1,
                           src1_stride, mask, mask_stride, w, h);
  } else if (suby) {
    blend_a64_mask_sy_avx2(dst, dst_stride, src0, src0_stride, src1,
                           src1_stride, mask, mask_stride, w, h);
  } else {
    blend_a64_mask_avx2(dst, dst_stride, src0, src0_stride, src1, src1_stride,
                        mask, mask_stride, w, h);
  }
}
901 
902 //////////////////////////////////////////////////////////////////////////////
903 // aom_highbd_blend_a64_d16_mask_avx2()
904 //////////////////////////////////////////////////////////////////////////////
905 
highbd_blend_a64_d16_mask_w4_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const __m256i * mask0,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)906 static INLINE void highbd_blend_a64_d16_mask_w4_avx2(
907     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
908     const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0,
909     const __m256i *round_offset, int shift, const __m256i *clip_low,
910     const __m256i *clip_high, const __m256i *mask_max) {
911   // Load 4x u16 pixels from each of 4 rows from each source
912   const __m256i s0 = _mm256_set_epi64x(*(uint64_t *)(src0 + 3 * src0_stride),
913                                        *(uint64_t *)(src0 + 2 * src0_stride),
914                                        *(uint64_t *)(src0 + 1 * src0_stride),
915                                        *(uint64_t *)(src0 + 0 * src0_stride));
916   const __m256i s1 = _mm256_set_epi64x(*(uint64_t *)(src1 + 3 * src1_stride),
917                                        *(uint64_t *)(src1 + 2 * src1_stride),
918                                        *(uint64_t *)(src1 + 1 * src1_stride),
919                                        *(uint64_t *)(src1 + 0 * src1_stride));
920   // Generate the inverse mask
921   const __m256i mask1 = _mm256_sub_epi16(*mask_max, *mask0);
922 
923   // Multiply each mask by the respective source
924   const __m256i mul0_highs = _mm256_mulhi_epu16(*mask0, s0);
925   const __m256i mul0_lows = _mm256_mullo_epi16(*mask0, s0);
926   const __m256i mul0h = _mm256_unpackhi_epi16(mul0_lows, mul0_highs);
927   const __m256i mul0l = _mm256_unpacklo_epi16(mul0_lows, mul0_highs);
928   // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
929   // lanes Later, packs does the same again which cancels this out with no need
930   // for a permute.  The intermediate values being reordered makes no difference
931 
932   const __m256i mul1_highs = _mm256_mulhi_epu16(mask1, s1);
933   const __m256i mul1_lows = _mm256_mullo_epi16(mask1, s1);
934   const __m256i mul1h = _mm256_unpackhi_epi16(mul1_lows, mul1_highs);
935   const __m256i mul1l = _mm256_unpacklo_epi16(mul1_lows, mul1_highs);
936 
937   const __m256i sumh = _mm256_add_epi32(mul0h, mul1h);
938   const __m256i suml = _mm256_add_epi32(mul0l, mul1l);
939 
940   const __m256i roundh =
941       _mm256_srai_epi32(_mm256_sub_epi32(sumh, *round_offset), shift);
942   const __m256i roundl =
943       _mm256_srai_epi32(_mm256_sub_epi32(suml, *round_offset), shift);
944 
945   const __m256i pack = _mm256_packs_epi32(roundl, roundh);
946   const __m256i clip =
947       _mm256_min_epi16(_mm256_max_epi16(pack, *clip_low), *clip_high);
948 
949   // _mm256_extract_epi64 doesn't exist on x86, so do it the old-fashioned way:
950   const __m128i cliph = _mm256_extracti128_si256(clip, 1);
951   xx_storel_64(dst + 3 * dst_stride, _mm_srli_si128(cliph, 8));
952   xx_storel_64(dst + 2 * dst_stride, cliph);
953   const __m128i clipl = _mm256_castsi256_si128(clip);
954   xx_storel_64(dst + 1 * dst_stride, _mm_srli_si128(clipl, 8));
955   xx_storel_64(dst + 0 * dst_stride, clipl);
956 }
957 
highbd_blend_a64_d16_mask_subw0_subh0_w4_avx2(uint16_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)958 static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w4_avx2(
959     uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
960     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
961     const uint8_t *mask, uint32_t mask_stride, int h,
962     const __m256i *round_offset, int shift, const __m256i *clip_low,
963     const __m256i *clip_high, const __m256i *mask_max) {
964   do {
965     // Load 8x u8 pixels from each of 4 rows of the mask, pad each to u16
966     const __m128i mask08 = _mm_set_epi32(*(uint32_t *)(mask + 3 * mask_stride),
967                                          *(uint32_t *)(mask + 2 * mask_stride),
968                                          *(uint32_t *)(mask + 1 * mask_stride),
969                                          *(uint32_t *)(mask + 0 * mask_stride));
970     const __m256i mask0 = _mm256_cvtepu8_epi16(mask08);
971 
972     highbd_blend_a64_d16_mask_w4_avx2(dst, dst_stride, src0, src0_stride, src1,
973                                       src1_stride, &mask0, round_offset, shift,
974                                       clip_low, clip_high, mask_max);
975 
976     dst += dst_stride * 4;
977     src0 += src0_stride * 4;
978     src1 += src1_stride * 4;
979     mask += mask_stride * 4;
980   } while (h -= 4);
981 }
982 
highbd_blend_a64_d16_mask_subw1_subh1_w4_avx2(uint16_t * dst,uint32_t dst_stride,const CONV_BUF_TYPE * src0,uint32_t src0_stride,const CONV_BUF_TYPE * src1,uint32_t src1_stride,const uint8_t * mask,uint32_t mask_stride,int h,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)983 static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w4_avx2(
984     uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
985     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
986     const uint8_t *mask, uint32_t mask_stride, int h,
987     const __m256i *round_offset, int shift, const __m256i *clip_low,
988     const __m256i *clip_high, const __m256i *mask_max) {
989   const __m256i one_b = _mm256_set1_epi8(1);
990   const __m256i two_w = _mm256_set1_epi16(2);
991   do {
992     // Load 8 pixels from each of 8 rows of mask,
993     // (saturating) add together rows then use madd to add adjacent pixels
994     // Finally, divide each value by 4 (with rounding)
995     const __m256i m0246 =
996         _mm256_set_epi64x(*(uint64_t *)(mask + 6 * mask_stride),
997                           *(uint64_t *)(mask + 4 * mask_stride),
998                           *(uint64_t *)(mask + 2 * mask_stride),
999                           *(uint64_t *)(mask + 0 * mask_stride));
1000     const __m256i m1357 =
1001         _mm256_set_epi64x(*(uint64_t *)(mask + 7 * mask_stride),
1002                           *(uint64_t *)(mask + 5 * mask_stride),
1003                           *(uint64_t *)(mask + 3 * mask_stride),
1004                           *(uint64_t *)(mask + 1 * mask_stride));
1005     const __m256i addrows = _mm256_adds_epu8(m0246, m1357);
1006     const __m256i adjacent = _mm256_maddubs_epi16(addrows, one_b);
1007     const __m256i mask0 =
1008         _mm256_srli_epi16(_mm256_add_epi16(adjacent, two_w), 2);
1009 
1010     highbd_blend_a64_d16_mask_w4_avx2(dst, dst_stride, src0, src0_stride, src1,
1011                                       src1_stride, &mask0, round_offset, shift,
1012                                       clip_low, clip_high, mask_max);
1013 
1014     dst += dst_stride * 4;
1015     src0 += src0_stride * 4;
1016     src1 += src1_stride * 4;
1017     mask += mask_stride * 8;
1018   } while (h -= 4);
1019 }
1020 
highbd_blend_a64_d16_mask_w8_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const __m256i * mask0a,const __m256i * mask0b,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1021 static INLINE void highbd_blend_a64_d16_mask_w8_avx2(
1022     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1023     const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0a,
1024     const __m256i *mask0b, const __m256i *round_offset, int shift,
1025     const __m256i *clip_low, const __m256i *clip_high,
1026     const __m256i *mask_max) {
1027   // Load 8x u16 pixels from each of 4 rows from each source
1028   const __m256i s0a =
1029       yy_loadu2_128(src0 + 0 * src0_stride, src0 + 1 * src0_stride);
1030   const __m256i s0b =
1031       yy_loadu2_128(src0 + 2 * src0_stride, src0 + 3 * src0_stride);
1032   const __m256i s1a =
1033       yy_loadu2_128(src1 + 0 * src1_stride, src1 + 1 * src1_stride);
1034   const __m256i s1b =
1035       yy_loadu2_128(src1 + 2 * src1_stride, src1 + 3 * src1_stride);
1036 
1037   // Generate inverse masks
1038   const __m256i mask1a = _mm256_sub_epi16(*mask_max, *mask0a);
1039   const __m256i mask1b = _mm256_sub_epi16(*mask_max, *mask0b);
1040 
1041   // Multiply sources by respective masks
1042   const __m256i mul0a_highs = _mm256_mulhi_epu16(*mask0a, s0a);
1043   const __m256i mul0a_lows = _mm256_mullo_epi16(*mask0a, s0a);
1044   const __m256i mul0ah = _mm256_unpackhi_epi16(mul0a_lows, mul0a_highs);
1045   const __m256i mul0al = _mm256_unpacklo_epi16(mul0a_lows, mul0a_highs);
1046   // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
1047   // lanes Later, packs does the same again which cancels this out with no need
1048   // for a permute.  The intermediate values being reordered makes no difference
1049 
1050   const __m256i mul1a_highs = _mm256_mulhi_epu16(mask1a, s1a);
1051   const __m256i mul1a_lows = _mm256_mullo_epi16(mask1a, s1a);
1052   const __m256i mul1ah = _mm256_unpackhi_epi16(mul1a_lows, mul1a_highs);
1053   const __m256i mul1al = _mm256_unpacklo_epi16(mul1a_lows, mul1a_highs);
1054 
1055   const __m256i sumah = _mm256_add_epi32(mul0ah, mul1ah);
1056   const __m256i sumal = _mm256_add_epi32(mul0al, mul1al);
1057 
1058   const __m256i mul0b_highs = _mm256_mulhi_epu16(*mask0b, s0b);
1059   const __m256i mul0b_lows = _mm256_mullo_epi16(*mask0b, s0b);
1060   const __m256i mul0bh = _mm256_unpackhi_epi16(mul0b_lows, mul0b_highs);
1061   const __m256i mul0bl = _mm256_unpacklo_epi16(mul0b_lows, mul0b_highs);
1062 
1063   const __m256i mul1b_highs = _mm256_mulhi_epu16(mask1b, s1b);
1064   const __m256i mul1b_lows = _mm256_mullo_epi16(mask1b, s1b);
1065   const __m256i mul1bh = _mm256_unpackhi_epi16(mul1b_lows, mul1b_highs);
1066   const __m256i mul1bl = _mm256_unpacklo_epi16(mul1b_lows, mul1b_highs);
1067 
1068   const __m256i sumbh = _mm256_add_epi32(mul0bh, mul1bh);
1069   const __m256i sumbl = _mm256_add_epi32(mul0bl, mul1bl);
1070 
1071   // Divide down each result, with rounding
1072   const __m256i roundah =
1073       _mm256_srai_epi32(_mm256_sub_epi32(sumah, *round_offset), shift);
1074   const __m256i roundal =
1075       _mm256_srai_epi32(_mm256_sub_epi32(sumal, *round_offset), shift);
1076   const __m256i roundbh =
1077       _mm256_srai_epi32(_mm256_sub_epi32(sumbh, *round_offset), shift);
1078   const __m256i roundbl =
1079       _mm256_srai_epi32(_mm256_sub_epi32(sumbl, *round_offset), shift);
1080 
1081   // Pack each i32 down to an i16 with saturation, then clip to valid range
1082   const __m256i packa = _mm256_packs_epi32(roundal, roundah);
1083   const __m256i clipa =
1084       _mm256_min_epi16(_mm256_max_epi16(packa, *clip_low), *clip_high);
1085   const __m256i packb = _mm256_packs_epi32(roundbl, roundbh);
1086   const __m256i clipb =
1087       _mm256_min_epi16(_mm256_max_epi16(packb, *clip_low), *clip_high);
1088 
1089   // Store 8x u16 pixels to each of 4 rows in the destination
1090   yy_storeu2_128(dst + 0 * dst_stride, dst + 1 * dst_stride, clipa);
1091   yy_storeu2_128(dst + 2 * dst_stride, dst + 3 * dst_stride, clipb);
1092 }
1093 
highbd_blend_a64_d16_mask_subw0_subh0_w8_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const uint8_t * mask,int mask_stride,int h,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1094 static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w8_avx2(
1095     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1096     const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
1097     int mask_stride, int h, const __m256i *round_offset, int shift,
1098     const __m256i *clip_low, const __m256i *clip_high,
1099     const __m256i *mask_max) {
1100   do {
1101     // Load 8x u8 pixels from each of 4 rows in the mask
1102     const __m128i mask0a8 =
1103         _mm_set_epi64x(*(uint64_t *)mask, *(uint64_t *)(mask + mask_stride));
1104     const __m128i mask0b8 =
1105         _mm_set_epi64x(*(uint64_t *)(mask + 2 * mask_stride),
1106                        *(uint64_t *)(mask + 3 * mask_stride));
1107     const __m256i mask0a = _mm256_cvtepu8_epi16(mask0a8);
1108     const __m256i mask0b = _mm256_cvtepu8_epi16(mask0b8);
1109 
1110     highbd_blend_a64_d16_mask_w8_avx2(
1111         dst, dst_stride, src0, src0_stride, src1, src1_stride, &mask0a, &mask0b,
1112         round_offset, shift, clip_low, clip_high, mask_max);
1113 
1114     dst += dst_stride * 4;
1115     src0 += src0_stride * 4;
1116     src1 += src1_stride * 4;
1117     mask += mask_stride * 4;
1118   } while (h -= 4);
1119 }
1120 
highbd_blend_a64_d16_mask_subw1_subh1_w8_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const uint8_t * mask,int mask_stride,int h,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1121 static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w8_avx2(
1122     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1123     const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
1124     int mask_stride, int h, const __m256i *round_offset, int shift,
1125     const __m256i *clip_low, const __m256i *clip_high,
1126     const __m256i *mask_max) {
1127   const __m256i one_b = _mm256_set1_epi8(1);
1128   const __m256i two_w = _mm256_set1_epi16(2);
1129   do {
1130     // Load 16x u8 pixels from each of 8 rows in the mask,
1131     // (saturating) add together rows then use madd to add adjacent pixels
1132     // Finally, divide each value by 4 (with rounding)
1133     const __m256i m02 =
1134         yy_loadu2_128(mask + 0 * mask_stride, mask + 2 * mask_stride);
1135     const __m256i m13 =
1136         yy_loadu2_128(mask + 1 * mask_stride, mask + 3 * mask_stride);
1137     const __m256i m0123 =
1138         _mm256_maddubs_epi16(_mm256_adds_epu8(m02, m13), one_b);
1139     const __m256i mask_0a =
1140         _mm256_srli_epi16(_mm256_add_epi16(m0123, two_w), 2);
1141     const __m256i m46 =
1142         yy_loadu2_128(mask + 4 * mask_stride, mask + 6 * mask_stride);
1143     const __m256i m57 =
1144         yy_loadu2_128(mask + 5 * mask_stride, mask + 7 * mask_stride);
1145     const __m256i m4567 =
1146         _mm256_maddubs_epi16(_mm256_adds_epu8(m46, m57), one_b);
1147     const __m256i mask_0b =
1148         _mm256_srli_epi16(_mm256_add_epi16(m4567, two_w), 2);
1149 
1150     highbd_blend_a64_d16_mask_w8_avx2(
1151         dst, dst_stride, src0, src0_stride, src1, src1_stride, &mask_0a,
1152         &mask_0b, round_offset, shift, clip_low, clip_high, mask_max);
1153 
1154     dst += dst_stride * 4;
1155     src0 += src0_stride * 4;
1156     src1 += src1_stride * 4;
1157     mask += mask_stride * 8;
1158   } while (h -= 4);
1159 }
1160 
highbd_blend_a64_d16_mask_w16_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const __m256i * mask0a,const __m256i * mask0b,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1161 static INLINE void highbd_blend_a64_d16_mask_w16_avx2(
1162     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1163     const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0a,
1164     const __m256i *mask0b, const __m256i *round_offset, int shift,
1165     const __m256i *clip_low, const __m256i *clip_high,
1166     const __m256i *mask_max) {
1167   // Load 16x pixels from each of 2 rows from each source
1168   const __m256i s0a = yy_loadu_256(src0);
1169   const __m256i s0b = yy_loadu_256(src0 + src0_stride);
1170   const __m256i s1a = yy_loadu_256(src1);
1171   const __m256i s1b = yy_loadu_256(src1 + src1_stride);
1172 
1173   // Calculate inverse masks
1174   const __m256i mask1a = _mm256_sub_epi16(*mask_max, *mask0a);
1175   const __m256i mask1b = _mm256_sub_epi16(*mask_max, *mask0b);
1176 
1177   // Multiply each source by appropriate mask
1178   const __m256i mul0a_highs = _mm256_mulhi_epu16(*mask0a, s0a);
1179   const __m256i mul0a_lows = _mm256_mullo_epi16(*mask0a, s0a);
1180   const __m256i mul0ah = _mm256_unpackhi_epi16(mul0a_lows, mul0a_highs);
1181   const __m256i mul0al = _mm256_unpacklo_epi16(mul0a_lows, mul0a_highs);
1182   // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
1183   // lanes Later, packs does the same again which cancels this out with no need
1184   // for a permute.  The intermediate values being reordered makes no difference
1185 
1186   const __m256i mul1a_highs = _mm256_mulhi_epu16(mask1a, s1a);
1187   const __m256i mul1a_lows = _mm256_mullo_epi16(mask1a, s1a);
1188   const __m256i mul1ah = _mm256_unpackhi_epi16(mul1a_lows, mul1a_highs);
1189   const __m256i mul1al = _mm256_unpacklo_epi16(mul1a_lows, mul1a_highs);
1190 
1191   const __m256i mulah = _mm256_add_epi32(mul0ah, mul1ah);
1192   const __m256i mulal = _mm256_add_epi32(mul0al, mul1al);
1193 
1194   const __m256i mul0b_highs = _mm256_mulhi_epu16(*mask0b, s0b);
1195   const __m256i mul0b_lows = _mm256_mullo_epi16(*mask0b, s0b);
1196   const __m256i mul0bh = _mm256_unpackhi_epi16(mul0b_lows, mul0b_highs);
1197   const __m256i mul0bl = _mm256_unpacklo_epi16(mul0b_lows, mul0b_highs);
1198 
1199   const __m256i mul1b_highs = _mm256_mulhi_epu16(mask1b, s1b);
1200   const __m256i mul1b_lows = _mm256_mullo_epi16(mask1b, s1b);
1201   const __m256i mul1bh = _mm256_unpackhi_epi16(mul1b_lows, mul1b_highs);
1202   const __m256i mul1bl = _mm256_unpacklo_epi16(mul1b_lows, mul1b_highs);
1203 
1204   const __m256i mulbh = _mm256_add_epi32(mul0bh, mul1bh);
1205   const __m256i mulbl = _mm256_add_epi32(mul0bl, mul1bl);
1206 
1207   const __m256i resah =
1208       _mm256_srai_epi32(_mm256_sub_epi32(mulah, *round_offset), shift);
1209   const __m256i resal =
1210       _mm256_srai_epi32(_mm256_sub_epi32(mulal, *round_offset), shift);
1211   const __m256i resbh =
1212       _mm256_srai_epi32(_mm256_sub_epi32(mulbh, *round_offset), shift);
1213   const __m256i resbl =
1214       _mm256_srai_epi32(_mm256_sub_epi32(mulbl, *round_offset), shift);
1215 
1216   // Signed saturating pack from i32 to i16:
1217   const __m256i packa = _mm256_packs_epi32(resal, resah);
1218   const __m256i packb = _mm256_packs_epi32(resbl, resbh);
1219 
1220   // Clip the values to the valid range
1221   const __m256i clipa =
1222       _mm256_min_epi16(_mm256_max_epi16(packa, *clip_low), *clip_high);
1223   const __m256i clipb =
1224       _mm256_min_epi16(_mm256_max_epi16(packb, *clip_low), *clip_high);
1225 
1226   // Store 16 pixels
1227   yy_storeu_256(dst, clipa);
1228   yy_storeu_256(dst + dst_stride, clipb);
1229 }
1230 
highbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const uint8_t * mask,int mask_stride,int h,int w,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1231 static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
1232     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1233     const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
1234     int mask_stride, int h, int w, const __m256i *round_offset, int shift,
1235     const __m256i *clip_low, const __m256i *clip_high,
1236     const __m256i *mask_max) {
1237   for (int i = 0; i < h; i += 2) {
1238     for (int j = 0; j < w; j += 16) {
1239       // Load 16x u8 alpha-mask values from each of two rows and pad to u16
1240       const __m128i masks_a8 = xx_loadu_128(mask + j);
1241       const __m128i masks_b8 = xx_loadu_128(mask + mask_stride + j);
1242       const __m256i mask0a = _mm256_cvtepu8_epi16(masks_a8);
1243       const __m256i mask0b = _mm256_cvtepu8_epi16(masks_b8);
1244 
1245       highbd_blend_a64_d16_mask_w16_avx2(
1246           dst + j, dst_stride, src0 + j, src0_stride, src1 + j, src1_stride,
1247           &mask0a, &mask0b, round_offset, shift, clip_low, clip_high, mask_max);
1248     }
1249     dst += dst_stride * 2;
1250     src0 += src0_stride * 2;
1251     src1 += src1_stride * 2;
1252     mask += mask_stride * 2;
1253   }
1254 }
1255 
highbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(uint16_t * dst,int dst_stride,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,const uint8_t * mask,int mask_stride,int h,int w,const __m256i * round_offset,int shift,const __m256i * clip_low,const __m256i * clip_high,const __m256i * mask_max)1256 static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
1257     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1258     const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
1259     int mask_stride, int h, int w, const __m256i *round_offset, int shift,
1260     const __m256i *clip_low, const __m256i *clip_high,
1261     const __m256i *mask_max) {
1262   const __m256i one_b = _mm256_set1_epi8(1);
1263   const __m256i two_w = _mm256_set1_epi16(2);
1264   for (int i = 0; i < h; i += 2) {
1265     for (int j = 0; j < w; j += 16) {
1266       // Load 32x u8 alpha-mask values from each of four rows
1267       // (saturating) add pairs of rows, then use madd to add adjacent values
1268       // Finally, divide down each result with rounding
1269       const __m256i m0 = yy_loadu_256(mask + 0 * mask_stride + 2 * j);
1270       const __m256i m1 = yy_loadu_256(mask + 1 * mask_stride + 2 * j);
1271       const __m256i m2 = yy_loadu_256(mask + 2 * mask_stride + 2 * j);
1272       const __m256i m3 = yy_loadu_256(mask + 3 * mask_stride + 2 * j);
1273 
1274       const __m256i m01_8 = _mm256_adds_epu8(m0, m1);
1275       const __m256i m23_8 = _mm256_adds_epu8(m2, m3);
1276 
1277       const __m256i m01 = _mm256_maddubs_epi16(m01_8, one_b);
1278       const __m256i m23 = _mm256_maddubs_epi16(m23_8, one_b);
1279 
1280       const __m256i mask0a = _mm256_srli_epi16(_mm256_add_epi16(m01, two_w), 2);
1281       const __m256i mask0b = _mm256_srli_epi16(_mm256_add_epi16(m23, two_w), 2);
1282 
1283       highbd_blend_a64_d16_mask_w16_avx2(
1284           dst + j, dst_stride, src0 + j, src0_stride, src1 + j, src1_stride,
1285           &mask0a, &mask0b, round_offset, shift, clip_low, clip_high, mask_max);
1286     }
1287     dst += dst_stride * 2;
1288     src0 += src0_stride * 2;
1289     src1 += src1_stride * 2;
1290     mask += mask_stride * 4;
1291   }
1292 }
1293 
// AVX2 entry point for the high-bitdepth d16 mask blend.
//
// Derives the rounding offset, shift amount and clipping bounds from
// conv_params and bd, then dispatches to a width-specific kernel:
//   - subw == 0 && subh == 0: mask at full resolution,
//   - subw == 1 && subh == 1: mask sub-sampled by two in both axes.
// Any other subw/subh combination is forwarded to
// aom_highbd_blend_a64_d16_mask_c().
//
// dst8 is converted to a uint16_t pointer with CONVERT_TO_SHORTPTR(). The
// asserts require w and h to be powers of two that are at least 4.
void aom_highbd_blend_a64_d16_mask_avx2(
    uint8_t *dst8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params, const int bd) {
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);

  // Rounding parameters shared by every kernel: the offset is subtracted from
  // each blended sum before the arithmetic right shift by `shift`.
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int shift = round_bits + AOM_BLEND_A64_ROUND_BITS;
  const int32_t round_offset =
      ((1 << (round_bits + bd)) + (1 << (round_bits + bd - 1)) -
       (1 << (round_bits - 1)))
      << AOM_BLEND_A64_ROUND_BITS;
  const __m256i v_round_offset = _mm256_set1_epi32(round_offset);

  // Output is clamped to [0, (1 << bd) - 1]; mask_max is the maximum alpha.
  const __m256i clip_low = _mm256_set1_epi16(0);
  const __m256i clip_high = _mm256_set1_epi16((1 << bd) - 1);
  const __m256i mask_max = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);

  // In-place operation is only allowed when the strides agree.
  assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));

  assert(h >= 4);
  assert(w >= 4);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  const int mask_full_res = (subw == 0 && subh == 0);
  const int mask_half_res = (subw == 1 && subh == 1);

  if (!mask_full_res && !mask_half_res) {
    // Sub-sampling in only one axis doesn't seem to happen very much, so fall
    // back to the vanilla C implementation instead of having all the optimised
    // code for these.
    aom_highbd_blend_a64_d16_mask_c(dst8, dst_stride, src0, src0_stride, src1,
                                    src1_stride, mask, mask_stride, w, h, subw,
                                    subh, conv_params, bd);
    return;
  }

  if (mask_full_res) {
    if (w == 4) {
      highbd_blend_a64_d16_mask_subw0_subh0_w4_avx2(
          dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
          mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
          &mask_max);
    } else if (w == 8) {
      highbd_blend_a64_d16_mask_subw0_subh0_w8_avx2(
          dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
          mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
          &mask_max);
    } else {  // >= 16
      highbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
          dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
          mask_stride, h, w, &v_round_offset, shift, &clip_low, &clip_high,
          &mask_max);
    }
    return;
  }

  // Remaining case: mask sub-sampled by two in both axes.
  if (w == 4) {
    highbd_blend_a64_d16_mask_subw1_subh1_w4_avx2(
        dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
        mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
        &mask_max);
  } else if (w == 8) {
    highbd_blend_a64_d16_mask_subw1_subh1_w8_avx2(
        dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
        mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
        &mask_max);
  } else {  // >= 16
    highbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
        dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
        mask_stride, h, w, &v_round_offset, shift, &clip_low, &clip_high,
        &mask_max);
  }
}
1373