1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <smmintrin.h>
14 
15 #include "config/aom_config.h"
16 
17 #include "aom_ports/mem.h"
18 #include "aom/aom_integer.h"
19 #include "aom_dsp/x86/synonyms.h"
20 
summary_all_sse4(const __m128i * sum_all)21 static INLINE int64_t summary_all_sse4(const __m128i *sum_all) {
22   int64_t sum;
23   const __m128i sum0 = _mm_cvtepu32_epi64(*sum_all);
24   const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum_all, 8));
25   const __m128i sum_2x64 = _mm_add_epi64(sum0, sum1);
26   const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
27   xx_storel_64(&sum, sum_1x64);
28   return sum;
29 }
30 
summary_32_sse4(const __m128i * sum32,__m128i * sum64)31 static INLINE void summary_32_sse4(const __m128i *sum32, __m128i *sum64) {
32   const __m128i sum0 = _mm_cvtepu32_epi64(*sum32);
33   const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum32, 8));
34   *sum64 = _mm_add_epi64(sum0, *sum64);
35   *sum64 = _mm_add_epi64(sum1, *sum64);
36 }
37 
sse_w16_sse4_1(__m128i * sum,const uint8_t * a,const uint8_t * b)38 static INLINE void sse_w16_sse4_1(__m128i *sum, const uint8_t *a,
39                                   const uint8_t *b) {
40   const __m128i v_a0 = xx_loadu_128(a);
41   const __m128i v_b0 = xx_loadu_128(b);
42   const __m128i v_a00_w = _mm_cvtepu8_epi16(v_a0);
43   const __m128i v_a01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_a0, 8));
44   const __m128i v_b00_w = _mm_cvtepu8_epi16(v_b0);
45   const __m128i v_b01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_b0, 8));
46   const __m128i v_d00_w = _mm_sub_epi16(v_a00_w, v_b00_w);
47   const __m128i v_d01_w = _mm_sub_epi16(v_a01_w, v_b01_w);
48   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d00_w, v_d00_w));
49   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w));
50 }
51 
aom_sse4x2_sse4_1(const uint8_t * a,int a_stride,const uint8_t * b,int b_stride,__m128i * sum)52 static INLINE void aom_sse4x2_sse4_1(const uint8_t *a, int a_stride,
53                                      const uint8_t *b, int b_stride,
54                                      __m128i *sum) {
55   const __m128i v_a0 = xx_loadl_32(a);
56   const __m128i v_a1 = xx_loadl_32(a + a_stride);
57   const __m128i v_b0 = xx_loadl_32(b);
58   const __m128i v_b1 = xx_loadl_32(b + b_stride);
59   const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1));
60   const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1));
61   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
62   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
63 }
aom_sse8_sse4_1(const uint8_t * a,const uint8_t * b,__m128i * sum)64 static INLINE void aom_sse8_sse4_1(const uint8_t *a, const uint8_t *b,
65                                    __m128i *sum) {
66   const __m128i v_a0 = xx_loadl_64(a);
67   const __m128i v_b0 = xx_loadl_64(b);
68   const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0);
69   const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0);
70   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
71   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
72 }
73 
aom_sse_sse4_1(const uint8_t * a,int a_stride,const uint8_t * b,int b_stride,int width,int height)74 int64_t aom_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b,
75                        int b_stride, int width, int height) {
76   int y = 0;
77   int64_t sse = 0;
78   __m128i sum = _mm_setzero_si128();
79   switch (width) {
80     case 4:
81       do {
82         aom_sse4x2_sse4_1(a, a_stride, b, b_stride, &sum);
83         a += a_stride << 1;
84         b += b_stride << 1;
85         y += 2;
86       } while (y < height);
87       sse = summary_all_sse4(&sum);
88       break;
89     case 8:
90       do {
91         aom_sse8_sse4_1(a, b, &sum);
92         a += a_stride;
93         b += b_stride;
94         y += 1;
95       } while (y < height);
96       sse = summary_all_sse4(&sum);
97       break;
98     case 16:
99       do {
100         sse_w16_sse4_1(&sum, a, b);
101         a += a_stride;
102         b += b_stride;
103         y += 1;
104       } while (y < height);
105       sse = summary_all_sse4(&sum);
106       break;
107     case 32:
108       do {
109         sse_w16_sse4_1(&sum, a, b);
110         sse_w16_sse4_1(&sum, a + 16, b + 16);
111         a += a_stride;
112         b += b_stride;
113         y += 1;
114       } while (y < height);
115       sse = summary_all_sse4(&sum);
116       break;
117     case 64:
118       do {
119         sse_w16_sse4_1(&sum, a, b);
120         sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
121         sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
122         sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
123         a += a_stride;
124         b += b_stride;
125         y += 1;
126       } while (y < height);
127       sse = summary_all_sse4(&sum);
128       break;
129     case 128:
130       do {
131         sse_w16_sse4_1(&sum, a, b);
132         sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
133         sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
134         sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
135         sse_w16_sse4_1(&sum, a + 16 * 4, b + 16 * 4);
136         sse_w16_sse4_1(&sum, a + 16 * 5, b + 16 * 5);
137         sse_w16_sse4_1(&sum, a + 16 * 6, b + 16 * 6);
138         sse_w16_sse4_1(&sum, a + 16 * 7, b + 16 * 7);
139         a += a_stride;
140         b += b_stride;
141         y += 1;
142       } while (y < height);
143       sse = summary_all_sse4(&sum);
144       break;
145     default:
146       if (width & 0x07) {
147         do {
148           int i = 0;
149           do {
150             aom_sse8_sse4_1(a + i, b + i, &sum);
151             aom_sse8_sse4_1(a + i + a_stride, b + i + b_stride, &sum);
152             i += 8;
153           } while (i + 4 < width);
154           aom_sse4x2_sse4_1(a + i, a_stride, b + i, b_stride, &sum);
155           a += (a_stride << 1);
156           b += (b_stride << 1);
157           y += 2;
158         } while (y < height);
159       } else {
160         do {
161           int i = 0;
162           do {
163             aom_sse8_sse4_1(a + i, b + i, &sum);
164             i += 8;
165           } while (i < width);
166           a += a_stride;
167           b += b_stride;
168           y += 1;
169         } while (y < height);
170       }
171       sse = summary_all_sse4(&sum);
172       break;
173   }
174 
175   return sse;
176 }
177 
highbd_sse_w4x2_sse4_1(__m128i * sum,const uint16_t * a,int a_stride,const uint16_t * b,int b_stride)178 static INLINE void highbd_sse_w4x2_sse4_1(__m128i *sum, const uint16_t *a,
179                                           int a_stride, const uint16_t *b,
180                                           int b_stride) {
181   const __m128i v_a0 = xx_loadl_64(a);
182   const __m128i v_a1 = xx_loadl_64(a + a_stride);
183   const __m128i v_b0 = xx_loadl_64(b);
184   const __m128i v_b1 = xx_loadl_64(b + b_stride);
185   const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1);
186   const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1);
187   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
188   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
189 }
190 
highbd_sse_w8_sse4_1(__m128i * sum,const uint16_t * a,const uint16_t * b)191 static INLINE void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a,
192                                         const uint16_t *b) {
193   const __m128i v_a_w = xx_loadu_128(a);
194   const __m128i v_b_w = xx_loadu_128(b);
195   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
196   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
197 }
198 
aom_highbd_sse_sse4_1(const uint8_t * a8,int a_stride,const uint8_t * b8,int b_stride,int width,int height)199 int64_t aom_highbd_sse_sse4_1(const uint8_t *a8, int a_stride,
200                               const uint8_t *b8, int b_stride, int width,
201                               int height) {
202   int32_t y = 0;
203   int64_t sse = 0;
204   uint16_t *a = CONVERT_TO_SHORTPTR(a8);
205   uint16_t *b = CONVERT_TO_SHORTPTR(b8);
206   __m128i sum = _mm_setzero_si128();
207   switch (width) {
208     case 4:
209       do {
210         highbd_sse_w4x2_sse4_1(&sum, a, a_stride, b, b_stride);
211         a += a_stride << 1;
212         b += b_stride << 1;
213         y += 2;
214       } while (y < height);
215       sse = summary_all_sse4(&sum);
216       break;
217     case 8:
218       do {
219         highbd_sse_w8_sse4_1(&sum, a, b);
220         a += a_stride;
221         b += b_stride;
222         y += 1;
223       } while (y < height);
224       sse = summary_all_sse4(&sum);
225       break;
226     case 16:
227       do {
228         int l = 0;
229         __m128i sum32 = _mm_setzero_si128();
230         do {
231           highbd_sse_w8_sse4_1(&sum32, a, b);
232           highbd_sse_w8_sse4_1(&sum32, a + 8, b + 8);
233           a += a_stride;
234           b += b_stride;
235           l += 1;
236         } while (l < 64 && l < (height - y));
237         summary_32_sse4(&sum32, &sum);
238         y += 64;
239       } while (y < height);
240       xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
241       break;
242     case 32:
243       do {
244         int l = 0;
245         __m128i sum32 = _mm_setzero_si128();
246         do {
247           highbd_sse_w8_sse4_1(&sum32, a, b);
248           highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
249           highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
250           highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
251           a += a_stride;
252           b += b_stride;
253           l += 1;
254         } while (l < 32 && l < (height - y));
255         summary_32_sse4(&sum32, &sum);
256         y += 32;
257       } while (y < height);
258       xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
259       break;
260     case 64:
261       do {
262         int l = 0;
263         __m128i sum32 = _mm_setzero_si128();
264         do {
265           highbd_sse_w8_sse4_1(&sum32, a, b);
266           highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
267           highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
268           highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
269           highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
270           highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
271           highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
272           highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
273           a += a_stride;
274           b += b_stride;
275           l += 1;
276         } while (l < 16 && l < (height - y));
277         summary_32_sse4(&sum32, &sum);
278         y += 16;
279       } while (y < height);
280       xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
281       break;
282     case 128:
283       do {
284         int l = 0;
285         __m128i sum32 = _mm_setzero_si128();
286         do {
287           highbd_sse_w8_sse4_1(&sum32, a, b);
288           highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
289           highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
290           highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
291           highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
292           highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
293           highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
294           highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
295           highbd_sse_w8_sse4_1(&sum32, a + 8 * 8, b + 8 * 8);
296           highbd_sse_w8_sse4_1(&sum32, a + 8 * 9, b + 8 * 9);
297           highbd_sse_w8_sse4_1(&sum32, a + 8 * 10, b + 8 * 10);
298           highbd_sse_w8_sse4_1(&sum32, a + 8 * 11, b + 8 * 11);
299           highbd_sse_w8_sse4_1(&sum32, a + 8 * 12, b + 8 * 12);
300           highbd_sse_w8_sse4_1(&sum32, a + 8 * 13, b + 8 * 13);
301           highbd_sse_w8_sse4_1(&sum32, a + 8 * 14, b + 8 * 14);
302           highbd_sse_w8_sse4_1(&sum32, a + 8 * 15, b + 8 * 15);
303           a += a_stride;
304           b += b_stride;
305           l += 1;
306         } while (l < 8 && l < (height - y));
307         summary_32_sse4(&sum32, &sum);
308         y += 8;
309       } while (y < height);
310       xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
311       break;
312     default:
313       if (width & 0x7) {
314         do {
315           __m128i sum32 = _mm_setzero_si128();
316           int i = 0;
317           do {
318             highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
319             highbd_sse_w8_sse4_1(&sum32, a + i + a_stride, b + i + b_stride);
320             i += 8;
321           } while (i + 4 < width);
322           highbd_sse_w4x2_sse4_1(&sum32, a + i, a_stride, b + i, b_stride);
323           a += (a_stride << 1);
324           b += (b_stride << 1);
325           y += 2;
326           summary_32_sse4(&sum32, &sum);
327         } while (y < height);
328       } else {
329         do {
330           int l = 0;
331           __m128i sum32 = _mm_setzero_si128();
332           do {
333             int i = 0;
334             do {
335               highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
336               i += 8;
337             } while (i < width);
338             a += a_stride;
339             b += b_stride;
340             l += 1;
341           } while (l < 8 && l < (height - y));
342           summary_32_sse4(&sum32, &sum);
343           y += 8;
344         } while (y < height);
345       }
346       xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
347       break;
348   }
349   return sse;
350 }
351