1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #ifndef AOM_AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
13 #define AOM_AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
14 
15 #include <stdlib.h>
16 #include <string.h>
17 #include <tmmintrin.h>
18 
19 #include "config/aom_config.h"
20 #include "config/aom_dsp_rtcd.h"
21 
22 #include "aom_dsp/blend.h"
23 
comp_mask_pred_16_ssse3(const uint8_t * src0,const uint8_t * src1,const uint8_t * mask,uint8_t * dst)24 static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
25                                            const uint8_t *src1,
26                                            const uint8_t *mask, uint8_t *dst) {
27   const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
28   const __m128i round_offset =
29       _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
30 
31   const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));
32   const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));
33   const __m128i aA = _mm_load_si128((const __m128i *)(mask));
34 
35   const __m128i maA = _mm_sub_epi8(alpha_max, aA);
36 
37   const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);
38   const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
39   const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
40   const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
41 
42   const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);
43   const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
44 
45   const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);
46   const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
47   _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));
48 }
49 
comp_mask_pred_8_ssse3(uint8_t * comp_pred,int height,const uint8_t * src0,int stride0,const uint8_t * src1,int stride1,const uint8_t * mask,int mask_stride)50 static INLINE void comp_mask_pred_8_ssse3(uint8_t *comp_pred, int height,
51                                           const uint8_t *src0, int stride0,
52                                           const uint8_t *src1, int stride1,
53                                           const uint8_t *mask,
54                                           int mask_stride) {
55   int i = 0;
56   const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
57   const __m128i round_offset =
58       _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
59   do {
60     // odd line A
61     const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
62     const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
63     const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
64     // even line B
65     const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
66     const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
67     const __m128i a = _mm_castps_si128(_mm_loadh_pi(
68         _mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));
69 
70     const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);
71     const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
72 
73     const __m128i ma = _mm_sub_epi8(alpha_max, a);
74     const __m128i aaA = _mm_unpacklo_epi8(a, ma);
75     const __m128i aaB = _mm_unpackhi_epi8(a, ma);
76 
77     const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);
78     const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
79     const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);
80     const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
81     const __m128i round = _mm_packus_epi16(roundA, roundB);
82     // comp_pred's stride == width == 8
83     _mm_store_si128((__m128i *)(comp_pred), round);
84     comp_pred += (8 << 1);
85     src0 += (stride0 << 1);
86     src1 += (stride1 << 1);
87     mask += (mask_stride << 1);
88     i += 2;
89   } while (i < height);
90 }
91 
92 #endif  // AOM_AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H_
93