1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <immintrin.h>
13
14 #include "config/aom_dsp_rtcd.h"
15
16 #include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
17
mm256_add_hi_lo_epi16(const __m256i val)18 static INLINE __m128i mm256_add_hi_lo_epi16(const __m256i val) {
19 return _mm_add_epi16(_mm256_castsi256_si128(val),
20 _mm256_extractf128_si256(val, 1));
21 }
22
mm256_add_hi_lo_epi32(const __m256i val)23 static INLINE __m128i mm256_add_hi_lo_epi32(const __m256i val) {
24 return _mm_add_epi32(_mm256_castsi256_si128(val),
25 _mm256_extractf128_si256(val, 1));
26 }
27
variance_kernel_avx2(const __m256i src,const __m256i ref,__m256i * const sse,__m256i * const sum)28 static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
29 __m256i *const sse,
30 __m256i *const sum) {
31 const __m256i adj_sub = _mm256_set1_epi16(0xff01); // (1,-1)
32
33 // unpack into pairs of source and reference values
34 const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
35 const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);
36
37 // subtract adjacent elements using src*1 + ref*-1
38 const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
39 const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
40 const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
41 const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
42
43 // add to the running totals
44 *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
45 *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
46 }
47
variance_final_from_32bit_sum_avx2(__m256i vsse,__m128i vsum,unsigned int * const sse)48 static INLINE int variance_final_from_32bit_sum_avx2(__m256i vsse, __m128i vsum,
49 unsigned int *const sse) {
50 // extract the low lane and add it to the high lane
51 const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);
52
53 // unpack sse and sum registers and add
54 const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
55 const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
56 const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
57
58 // perform the final summation and extract the results
59 const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
60 *((int *)sse) = _mm_cvtsi128_si32(res);
61 return _mm_extract_epi32(res, 1);
62 }
63
64 // handle pixels (<= 512)
variance_final_512_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)65 static INLINE int variance_final_512_avx2(__m256i vsse, __m256i vsum,
66 unsigned int *const sse) {
67 // extract the low lane and add it to the high lane
68 const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
69 const __m128i vsum_64 = _mm_add_epi16(vsum_128, _mm_srli_si128(vsum_128, 8));
70 const __m128i sum_int32 = _mm_cvtepi16_epi32(vsum_64);
71 return variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse);
72 }
73
74 // handle 1024 pixels (32x32, 16x64, 64x16)
variance_final_1024_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)75 static INLINE int variance_final_1024_avx2(__m256i vsse, __m256i vsum,
76 unsigned int *const sse) {
77 // extract the low lane and add it to the high lane
78 const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
79 const __m128i vsum_64 =
80 _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
81 _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
82 return variance_final_from_32bit_sum_avx2(vsse, vsum_64, sse);
83 }
84
sum_to_32bit_avx2(const __m256i sum)85 static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
86 const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
87 const __m256i sum_hi =
88 _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
89 return _mm256_add_epi32(sum_lo, sum_hi);
90 }
91
92 // handle 2048 pixels (32x64, 64x32)
variance_final_2048_avx2(__m256i vsse,__m256i vsum,unsigned int * const sse)93 static INLINE int variance_final_2048_avx2(__m256i vsse, __m256i vsum,
94 unsigned int *const sse) {
95 vsum = sum_to_32bit_avx2(vsum);
96 const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);
97 return variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);
98 }
99
variance16_kernel_avx2(const uint8_t * const src,const int src_stride,const uint8_t * const ref,const int ref_stride,__m256i * const sse,__m256i * const sum)100 static INLINE void variance16_kernel_avx2(
101 const uint8_t *const src, const int src_stride, const uint8_t *const ref,
102 const int ref_stride, __m256i *const sse, __m256i *const sum) {
103 const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
104 const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
105 const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
106 const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
107 const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
108 const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
109 variance_kernel_avx2(s, r, sse, sum);
110 }
111
variance32_kernel_avx2(const uint8_t * const src,const uint8_t * const ref,__m256i * const sse,__m256i * const sum)112 static INLINE void variance32_kernel_avx2(const uint8_t *const src,
113 const uint8_t *const ref,
114 __m256i *const sse,
115 __m256i *const sum) {
116 const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
117 const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
118 variance_kernel_avx2(s, r, sse, sum);
119 }
120
variance16_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)121 static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
122 const uint8_t *ref, const int ref_stride,
123 const int h, __m256i *const vsse,
124 __m256i *const vsum) {
125 *vsum = _mm256_setzero_si256();
126
127 for (int i = 0; i < h; i += 2) {
128 variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
129 src += 2 * src_stride;
130 ref += 2 * ref_stride;
131 }
132 }
133
variance32_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)134 static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
135 const uint8_t *ref, const int ref_stride,
136 const int h, __m256i *const vsse,
137 __m256i *const vsum) {
138 *vsum = _mm256_setzero_si256();
139
140 for (int i = 0; i < h; i++) {
141 variance32_kernel_avx2(src, ref, vsse, vsum);
142 src += src_stride;
143 ref += ref_stride;
144 }
145 }
146
variance64_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)147 static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
148 const uint8_t *ref, const int ref_stride,
149 const int h, __m256i *const vsse,
150 __m256i *const vsum) {
151 *vsum = _mm256_setzero_si256();
152
153 for (int i = 0; i < h; i++) {
154 variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
155 variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
156 src += src_stride;
157 ref += ref_stride;
158 }
159 }
160
variance128_avx2(const uint8_t * src,const int src_stride,const uint8_t * ref,const int ref_stride,const int h,__m256i * const vsse,__m256i * const vsum)161 static INLINE void variance128_avx2(const uint8_t *src, const int src_stride,
162 const uint8_t *ref, const int ref_stride,
163 const int h, __m256i *const vsse,
164 __m256i *const vsum) {
165 *vsum = _mm256_setzero_si256();
166
167 for (int i = 0; i < h; i++) {
168 variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
169 variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
170 variance32_kernel_avx2(src + 64, ref + 64, vsse, vsum);
171 variance32_kernel_avx2(src + 96, ref + 96, vsse, vsum);
172 src += src_stride;
173 ref += ref_stride;
174 }
175 }
176
// Defines aom_variance<bw>x<bh>_avx2 for blocks small enough to accumulate the
// whole block's sum in 16-bit lanes in one pass. bits must be log2(bw * bh).
// max_pixel selects the final reduction helper (512, 1024 or 2048).
// Returns sse - sum^2 / (bw * bh) and stores the SSE in *sse.
#define AOM_VAR_NO_LOOP_AVX2(bw, bh, bits, max_pixel)                         \
  unsigned int aom_variance##bw##x##bh##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    __m256i sse_acc = _mm256_setzero_si256();                                 \
    __m256i sum_acc;                                                          \
    variance##bw##_avx2(src, src_stride, ref, ref_stride, bh, &sse_acc,       \
                        &sum_acc);                                            \
    const int total =                                                         \
        variance_final_##max_pixel##_avx2(sse_acc, sum_acc, sse);             \
    const int64_t total_sq = (int64_t)total * total;                          \
    return *sse - (uint32_t)(total_sq >> bits);                               \
  }

AOM_VAR_NO_LOOP_AVX2(16, 4, 6, 512);
AOM_VAR_NO_LOOP_AVX2(16, 8, 7, 512);
AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512);
AOM_VAR_NO_LOOP_AVX2(16, 32, 9, 512);
AOM_VAR_NO_LOOP_AVX2(16, 64, 10, 1024);

AOM_VAR_NO_LOOP_AVX2(32, 8, 8, 512);
AOM_VAR_NO_LOOP_AVX2(32, 16, 9, 512);
AOM_VAR_NO_LOOP_AVX2(32, 32, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(32, 64, 11, 2048);

AOM_VAR_NO_LOOP_AVX2(64, 16, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(64, 32, 11, 2048);
201
202 #define AOM_VAR_LOOP_AVX2(bw, bh, bits, uh) \
203 unsigned int aom_variance##bw##x##bh##_avx2( \
204 const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
205 unsigned int *sse) { \
206 __m256i vsse = _mm256_setzero_si256(); \
207 __m256i vsum = _mm256_setzero_si256(); \
208 for (int i = 0; i < (bh / uh); i++) { \
209 __m256i vsum16; \
210 variance##bw##_avx2(src, src_stride, ref, ref_stride, uh, &vsse, \
211 &vsum16); \
212 vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16)); \
213 src += uh * src_stride; \
214 ref += uh * ref_stride; \
215 } \
216 const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum); \
217 const int sum = variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse); \
218 return *sse - (unsigned int)(((int64_t)sum * sum) >> bits); \
219 }
220
221 AOM_VAR_LOOP_AVX2(64, 64, 12, 32); // 64x32 * ( 64/32)
222 AOM_VAR_LOOP_AVX2(64, 128, 13, 32); // 64x32 * (128/32)
223 AOM_VAR_LOOP_AVX2(128, 64, 13, 16); // 128x16 * ( 64/16)
224 AOM_VAR_LOOP_AVX2(128, 128, 14, 16); // 128x16 * (128/16)
225
// Returns the sum of squared errors over a 16x16 block and also stores it in
// *sse. The variance returned by aom_variance16x16_avx2 is discarded; only its
// SSE output is used.
unsigned int aom_mse16x16_avx2(const uint8_t *src, int src_stride,
                               const uint8_t *ref, int ref_stride,
                               unsigned int *sse) {
  (void)aom_variance16x16_avx2(src, src_stride, ref, ref_stride, sse);
  return *sse;
}
232
// Sub-pixel variance helper for a 32-pixel-wide column of `height` rows. Its
// definition is not part of this file's visible code. The macros below treat
// the return value as the signed sum of differences (stored into an int). They
// read the sum of squared differences from *sse, and cap height at 64 to avoid
// overflow inside the helper.
unsigned int aom_sub_pixel_variance32xh_avx2(const uint8_t *src, int src_stride,
                                             int x_offset, int y_offset,
                                             const uint8_t *dst, int dst_stride,
                                             int height, unsigned int *sse);

// Variant of the helper above that also takes a second predictor sec with row
// stride sec_stride. The return value and *sseptr are used the same way.
// NOTE(review): the exact use of sec (presumably averaging with the
// prediction) is not visible here; confirm against the implementation.
unsigned int aom_sub_pixel_avg_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, const uint8_t *sec, int sec_stride,
    int height, unsigned int *sseptr);
242
// Defines aom_sub_pixel_variance<w>x<h>_avx2 by tiling the block into
// wf-pixel-wide columns and strips of at most 64 rows. Each tile is handed to
// aom_sub_pixel_variance<wf>xh_avx2, and the per-tile sums are combined as
// sse - sum^2 >> (wlog2 + hlog2). wlog2 and hlog2 must equal log2(w) and
// log2(h).
#define AOM_SUB_PIXEL_VAR_AVX2(w, h, wf, wlog2, hlog2)                        \
  unsigned int aom_sub_pixel_variance##w##x##h##_avx2(                        \
      const uint8_t *src, int src_stride, int x_offset, int y_offset,        \
      const uint8_t *dst, int dst_stride, unsigned int *sse_ptr) {           \
    /* Cap the strip height so the helper's accumulators cannot overflow. */ \
    const int hf = AOMMIN(h, 64);                                             \
    unsigned int sse_total = 0;                                               \
    int sum_total = 0;                                                        \
    for (int tx = 0; tx < (w / wf); ++tx) {                                   \
      for (int ty = 0; ty < (h / hf); ++ty) {                                 \
        const int src_off = ty * hf * src_stride + tx * wf;                   \
        const int dst_off = ty * hf * dst_stride + tx * wf;                   \
        unsigned int tile_sse;                                                \
        const int tile_sum = aom_sub_pixel_variance##wf##xh_avx2(             \
            src + src_off, src_stride, x_offset, y_offset, dst + dst_off,     \
            dst_stride, hf, &tile_sse);                                       \
        sum_total += tile_sum;                                                \
        sse_total += tile_sse;                                                \
      }                                                                       \
    }                                                                         \
    *sse_ptr = sse_total;                                                     \
    return sse_total - (unsigned int)(((int64_t)sum_total * sum_total) >>     \
                                      (wlog2 + hlog2));                       \
  }

AOM_SUB_PIXEL_VAR_AVX2(128, 128, 32, 7, 7);
AOM_SUB_PIXEL_VAR_AVX2(128, 64, 32, 7, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 128, 32, 6, 7);
AOM_SUB_PIXEL_VAR_AVX2(64, 64, 32, 6, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 32, 32, 6, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 64, 32, 5, 6);
AOM_SUB_PIXEL_VAR_AVX2(32, 32, 32, 5, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 16, 32, 5, 4);
279
// Defines aom_sub_pixel_avg_variance<w>x<h>_avx2. It uses the same tiling as
// AOM_SUB_PIXEL_VAR_AVX2 (wf-wide columns, strips of at most 64 rows) and also
// walks the second predictor sec, whose row stride is w.
#define AOM_SUB_PIXEL_AVG_VAR_AVX2(w, h, wf, wlog2, hlog2)                    \
  unsigned int aom_sub_pixel_avg_variance##w##x##h##_avx2(                    \
      const uint8_t *src, int src_stride, int x_offset, int y_offset,        \
      const uint8_t *dst, int dst_stride, unsigned int *sse_ptr,             \
      const uint8_t *sec) {                                                   \
    /* Cap the strip height so the helper's accumulators cannot overflow. */ \
    const int hf = AOMMIN(h, 64);                                             \
    unsigned int sse_total = 0;                                               \
    int sum_total = 0;                                                        \
    for (int tx = 0; tx < (w / wf); ++tx) {                                   \
      for (int ty = 0; ty < (h / hf); ++ty) {                                 \
        const int src_off = ty * hf * src_stride + tx * wf;                   \
        const int dst_off = ty * hf * dst_stride + tx * wf;                   \
        const int sec_off = ty * hf * w + tx * wf;                            \
        unsigned int tile_sse;                                                \
        const int tile_sum = aom_sub_pixel_avg_variance##wf##xh_avx2(         \
            src + src_off, src_stride, x_offset, y_offset, dst + dst_off,     \
            dst_stride, sec + sec_off, w, hf, &tile_sse);                     \
        sum_total += tile_sum;                                                \
        sse_total += tile_sse;                                                \
      }                                                                       \
    }                                                                         \
    *sse_ptr = sse_total;                                                     \
    return sse_total - (unsigned int)(((int64_t)sum_total * sum_total) >>     \
                                      (wlog2 + hlog2));                       \
  }

AOM_SUB_PIXEL_AVG_VAR_AVX2(128, 128, 32, 7, 7);
AOM_SUB_PIXEL_AVG_VAR_AVX2(128, 64, 32, 7, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 128, 32, 6, 7);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 64, 32, 6, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(64, 32, 32, 6, 5);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 64, 32, 5, 6);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 32, 32, 5, 5);
AOM_SUB_PIXEL_AVG_VAR_AVX2(32, 16, 32, 5, 4);
320
mm256_loadu2(const uint8_t * p0,const uint8_t * p1)321 static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
322 const __m256i d =
323 _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
324 return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
325 }
326
mm256_loadu2_16(const uint16_t * p0,const uint16_t * p1)327 static INLINE __m256i mm256_loadu2_16(const uint16_t *p0, const uint16_t *p1) {
328 const __m256i d =
329 _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
330 return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
331 }
332
comp_mask_pred_line_avx2(const __m256i s0,const __m256i s1,const __m256i a,uint8_t * comp_pred)333 static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
334 const __m256i a,
335 uint8_t *comp_pred) {
336 const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
337 const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
338 const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));
339
340 const __m256i ma = _mm256_sub_epi8(alpha_max, a);
341
342 const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);
343 const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
344 const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
345 const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);
346
347 const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);
348 const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
349 const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);
350 const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);
351
352 const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);
353 _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
354 }
355
aom_comp_mask_pred_avx2(uint8_t * comp_pred,const uint8_t * pred,int width,int height,const uint8_t * ref,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)356 void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
357 int height, const uint8_t *ref, int ref_stride,
358 const uint8_t *mask, int mask_stride,
359 int invert_mask) {
360 int i = 0;
361 const uint8_t *src0 = invert_mask ? pred : ref;
362 const uint8_t *src1 = invert_mask ? ref : pred;
363 const int stride0 = invert_mask ? width : ref_stride;
364 const int stride1 = invert_mask ? ref_stride : width;
365 if (width == 8) {
366 comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
367 mask, mask_stride);
368 } else if (width == 16) {
369 do {
370 const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
371 const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
372 const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
373 src0 += (stride0 << 1);
374 src1 += (stride1 << 1);
375 mask += (mask_stride << 1);
376 const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
377 const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
378 const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
379 src0 += (stride0 << 1);
380 src1 += (stride1 << 1);
381 mask += (mask_stride << 1);
382 // comp_pred's stride == width == 16
383 comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
384 comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
385 comp_pred += (16 << 2);
386 i += 4;
387 } while (i < height);
388 } else { // for width == 32
389 do {
390 const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
391 const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
392 const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
393
394 const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
395 const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
396 const __m256i aB =
397 _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
398
399 comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
400 comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
401 comp_pred += (32 << 1);
402
403 src0 += (stride0 << 1);
404 src1 += (stride1 << 1);
405 mask += (mask_stride << 1);
406 i += 2;
407 } while (i < height);
408 }
409 }
410
highbd_comp_mask_pred_line_avx2(const __m256i s0,const __m256i s1,const __m256i a)411 static INLINE __m256i highbd_comp_mask_pred_line_avx2(const __m256i s0,
412 const __m256i s1,
413 const __m256i a) {
414 const __m256i alpha_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
415 const __m256i round_const =
416 _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
417 const __m256i a_inv = _mm256_sub_epi16(alpha_max, a);
418
419 const __m256i s_lo = _mm256_unpacklo_epi16(s0, s1);
420 const __m256i a_lo = _mm256_unpacklo_epi16(a, a_inv);
421 const __m256i pred_lo = _mm256_madd_epi16(s_lo, a_lo);
422 const __m256i pred_l = _mm256_srai_epi32(
423 _mm256_add_epi32(pred_lo, round_const), AOM_BLEND_A64_ROUND_BITS);
424
425 const __m256i s_hi = _mm256_unpackhi_epi16(s0, s1);
426 const __m256i a_hi = _mm256_unpackhi_epi16(a, a_inv);
427 const __m256i pred_hi = _mm256_madd_epi16(s_hi, a_hi);
428 const __m256i pred_h = _mm256_srai_epi32(
429 _mm256_add_epi32(pred_hi, round_const), AOM_BLEND_A64_ROUND_BITS);
430
431 const __m256i comp = _mm256_packs_epi32(pred_l, pred_h);
432
433 return comp;
434 }
435
aom_highbd_comp_mask_pred_avx2(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)436 void aom_highbd_comp_mask_pred_avx2(uint8_t *comp_pred8, const uint8_t *pred8,
437 int width, int height, const uint8_t *ref8,
438 int ref_stride, const uint8_t *mask,
439 int mask_stride, int invert_mask) {
440 int i = 0;
441 uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
442 uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
443 uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
444 const uint16_t *src0 = invert_mask ? pred : ref;
445 const uint16_t *src1 = invert_mask ? ref : pred;
446 const int stride0 = invert_mask ? width : ref_stride;
447 const int stride1 = invert_mask ? ref_stride : width;
448 const __m256i zero = _mm256_setzero_si256();
449
450 if (width == 8) {
451 do {
452 const __m256i s0 = mm256_loadu2_16(src0 + stride0, src0);
453 const __m256i s1 = mm256_loadu2_16(src1 + stride1, src1);
454
455 const __m128i m_l = _mm_loadl_epi64((const __m128i *)mask);
456 const __m128i m_h = _mm_loadl_epi64((const __m128i *)(mask + 8));
457
458 __m256i m = _mm256_castsi128_si256(m_l);
459 m = _mm256_insertf128_si256(m, m_h, 1);
460 const __m256i m_16 = _mm256_unpacklo_epi8(m, zero);
461
462 const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
463
464 _mm_storeu_si128((__m128i *)(comp_pred), _mm256_castsi256_si128(comp));
465
466 _mm_storeu_si128((__m128i *)(comp_pred + width),
467 _mm256_extractf128_si256(comp, 1));
468
469 src0 += (stride0 << 1);
470 src1 += (stride1 << 1);
471 mask += (mask_stride << 1);
472 comp_pred += (width << 1);
473 i += 2;
474 } while (i < height);
475 } else if (width == 16) {
476 do {
477 const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0));
478 const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1));
479 const __m256i m_16 =
480 _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));
481
482 const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
483
484 _mm256_storeu_si256((__m256i *)comp_pred, comp);
485
486 src0 += stride0;
487 src1 += stride1;
488 mask += mask_stride;
489 comp_pred += width;
490 i += 1;
491 } while (i < height);
492 } else if (width == 32) {
493 do {
494 const __m256i s0 = _mm256_loadu_si256((const __m256i *)src0);
495 const __m256i s2 = _mm256_loadu_si256((const __m256i *)(src0 + 16));
496 const __m256i s1 = _mm256_loadu_si256((const __m256i *)src1);
497 const __m256i s3 = _mm256_loadu_si256((const __m256i *)(src1 + 16));
498
499 const __m256i m01_16 =
500 _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));
501 const __m256i m23_16 =
502 _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(mask + 16)));
503
504 const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m01_16);
505 const __m256i comp1 = highbd_comp_mask_pred_line_avx2(s2, s3, m23_16);
506
507 _mm256_storeu_si256((__m256i *)comp_pred, comp);
508 _mm256_storeu_si256((__m256i *)(comp_pred + 16), comp1);
509
510 src0 += stride0;
511 src1 += stride1;
512 mask += mask_stride;
513 comp_pred += width;
514 i += 1;
515 } while (i < height);
516 }
517 }
518