// Copyright 2020 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_AVX2
#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/common.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_avx2.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

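// Adds the two partial filter sums in |s|, rounds, adds back the scaled-down
// |filter[3]| offset correction carried in |s_3x128|, and clamps the result to
// the 13-bit intermediate range before storing it to |wiener_buffer|.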
inline void WienerHorizontalClip(const __m256i s[2], const __m256i s_3x128,
                                 int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit =
      (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
  const __m256i offsets = _mm256_set1_epi16(-offset);
  const __m256i limits = _mm256_set1_epi16(limit - offset);
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsHorizontal - 1));
  // The sum range here is [-128 * 255, 90 * 255].
  const __m256i madd = _mm256_add_epi16(s[0], s[1]);
  const __m256i sum = _mm256_add_epi16(madd, round);
  const __m256i rounded_sum0 =
      _mm256_srai_epi16(sum, kInterRoundBitsHorizontal);
  // Add back scaled down offset correction.
  const __m256i rounded_sum1 = _mm256_add_epi16(rounded_sum0, s_3x128);
  const __m256i d0 = _mm256_max_epi16(rounded_sum1, offsets);
  const __m256i d1 = _mm256_min_epi16(d0, limits);
  StoreAligned32(wiener_buffer, d1);
}

// Using _mm256_alignr_epi8() is about 8% faster than loading all and unpacking,
// because the compiler generates redundant code when loading all and unpacking.
inline void WienerHorizontalTap7Kernel(const __m256i s[2],
                                       const __m256i filter[4],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9);
  const auto s67 = _mm256_alignr_epi8(s[1], s[0], 13);
  __m256i madds[4];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  madds[2] = _mm256_maddubs_epi16(s45, filter[2]);
  madds[3] = _mm256_maddubs_epi16(s67, filter[3]);
  madds[0] = _mm256_add_epi16(madds[0], madds[2]);
  madds[1] = _mm256_add_epi16(madds[1], madds[3]);
  const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s23, 8),
                                            7 - kInterRoundBitsHorizontal);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

inline void WienerHorizontalTap5Kernel(const __m256i s[2],
                                       const __m256i filter[3],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9);
  __m256i madds[3];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  madds[2] = _mm256_maddubs_epi16(s45, filter[2]);
  madds[0] = _mm256_add_epi16(madds[0], madds[2]);
  const __m256i s_3x128 = _mm256_srli_epi16(_mm256_slli_epi16(s23, 8),
                                            kInterRoundBitsHorizontal + 1);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

inline void WienerHorizontalTap3Kernel(const __m256i s[2],
                                       const __m256i filter[2],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  __m256i madds[2];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s01, 8),
                                            7 - kInterRoundBitsHorizontal);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

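// Horizontal Wiener filtering. Each function filters one row at a time,
// producing 32 intermediate values per loop iteration. The source bytes are
// duplicated with unpack/alignr so that _mm256_maddubs_epi16() applies two
// packed filter taps per multiply-add.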
inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[4];
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0100));
  filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302));
  filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0102));
  filter[3] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8000)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap7Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap7Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[3];
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0201));
  filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0203));
  filter[2] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8001)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap5Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap5Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302));
  filter[1] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8002)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap3Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap3Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      const __m256i s = LoadUnaligned32(src + x);
      const __m256i s0 = _mm256_unpacklo_epi8(s, _mm256_setzero_si256());
      const __m256i s1 = _mm256_unpackhi_epi8(s, _mm256_setzero_si256());
      __m256i d[2];
      d[0] = _mm256_slli_epi16(s0, 4);
      d[1] = _mm256_slli_epi16(s1, 4);
      StoreAligned64(*wiener_buffer + x, d);
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

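// Vertical Wiener filtering helpers. The inputs are the 16-bit intermediate
// values written by the horizontal pass; _mm256_madd_epi16() accumulates pairs
// of taps into 32-bit sums, which are rounded back down by
// kInterRoundBitsVertical and re-packed to 16 bits.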
inline __m256i WienerVertical7(const __m256i a[2], const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1));
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i sum0 = _mm256_add_epi32(round, madd0);
  const __m256i sum1 = _mm256_add_epi32(sum0, madd1);
  return _mm256_srai_epi32(sum1, kInterRoundBitsVertical);
}

inline __m256i WienerVertical5(const __m256i a[2], const __m256i filter[2]) {
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i sum = _mm256_add_epi32(madd0, madd1);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVertical3(const __m256i a, const __m256i filter) {
  const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1));
  const __m256i madd = _mm256_madd_epi16(a, filter);
  const __m256i sum = _mm256_add_epi32(round, madd);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVerticalFilter7(const __m256i a[7],
                                     const __m256i filter[2]) {
  __m256i b[2];
  const __m256i a06 = _mm256_add_epi16(a[0], a[6]);
  const __m256i a15 = _mm256_add_epi16(a[1], a[5]);
  const __m256i a24 = _mm256_add_epi16(a[2], a[4]);
  b[0] = _mm256_unpacklo_epi16(a06, a15);
  b[1] = _mm256_unpacklo_epi16(a24, a[3]);
  const __m256i sum0 = WienerVertical7(b, filter);
  b[0] = _mm256_unpackhi_epi16(a06, a15);
  b[1] = _mm256_unpackhi_epi16(a24, a[3]);
  const __m256i sum1 = WienerVertical7(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

inline __m256i WienerVerticalFilter5(const __m256i a[5],
                                     const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1));
  __m256i b[2];
  const __m256i a04 = _mm256_add_epi16(a[0], a[4]);
  const __m256i a13 = _mm256_add_epi16(a[1], a[3]);
  b[0] = _mm256_unpacklo_epi16(a04, a13);
  b[1] = _mm256_unpacklo_epi16(a[2], round);
  const __m256i sum0 = WienerVertical5(b, filter);
  b[0] = _mm256_unpackhi_epi16(a04, a13);
  b[1] = _mm256_unpackhi_epi16(a[2], round);
  const __m256i sum1 = WienerVertical5(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

inline __m256i WienerVerticalFilter3(const __m256i a[3], const __m256i filter) {
  __m256i b;
  const __m256i a02 = _mm256_add_epi16(a[0], a[2]);
  b = _mm256_unpacklo_epi16(a02, a[1]);
  const __m256i sum0 = WienerVertical3(b, filter);
  b = _mm256_unpackhi_epi16(a02, a[1]);
  const __m256i sum1 = WienerVertical3(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

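// The tap kernels load the rows they need from |wiener_buffer|, spaced
// |wiener_stride| apart. The Kernel2 variants reuse those rows, loading only
// one extra row, to produce a second output row.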
inline __m256i WienerVerticalTap7Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[7]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  a[6] = LoadAligned32(wiener_buffer + 6 * wiener_stride);
  return WienerVerticalFilter7(a, filter);
}

inline __m256i WienerVerticalTap5Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[5]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  return WienerVerticalFilter5(a, filter);
}

inline __m256i WienerVerticalTap3Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter, __m256i a[3]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  return WienerVerticalFilter3(a, filter);
}

inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[8];
  d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = LoadAligned32(wiener_buffer + 7 * wiener_stride);
  d[1] = WienerVerticalFilter7(a + 1, filter);
}

inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[6];
  d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  d[1] = WienerVerticalFilter5(a + 1, filter);
}

inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter, __m256i d[2]) {
  __m256i a[4];
  d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  d[1] = WienerVerticalFilter3(a + 1, filter);
}

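// The vertical tap functions below write two output rows per loop iteration
// and finish with a single-row pass when |height| is odd.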
inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients));
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi32(c, 0x0);
  filter[1] = _mm256_shuffle_epi32(c, 0x55);
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap7Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap7Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[7];
      const __m256i d0 =
          WienerVerticalTap7Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap7Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[3], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastd_epi32(Load4(coefficients));
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi32(c, 0);
  filter[1] =
      _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[2]));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap5Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap5Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[5];
      const __m256i d0 =
          WienerVerticalTap5Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap5Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[2], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i filter =
      _mm256_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap3Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap3Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[3];
      const __m256i d0 =
          WienerVerticalTap3Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap3Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint8_t* const dst) {
  const __m256i a0 = LoadAligned32(wiener_buffer + 0);
  const __m256i a1 = LoadAligned32(wiener_buffer + 16);
  const __m256i b0 = _mm256_add_epi16(a0, _mm256_set1_epi16(8));
  const __m256i b1 = _mm256_add_epi16(a1, _mm256_set1_epi16(8));
  const __m256i c0 = _mm256_srai_epi16(b0, 4);
  const __m256i c1 = _mm256_srai_epi16(b1, 4);
  const __m256i d = _mm256_packus_epi16(c0, c1);
  StoreUnaligned32(dst, d);
}

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x);
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      x += 32;
    } while (x < width);
  }
}

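// Wiener filtering is done in two passes: a horizontal pass stores 16-bit
// intermediate values in |restoration_buffer->wiener_buffer|, and a vertical
// pass reads them back and writes the final 8-bit pixels to |dest|. The tap
// count of each pass is derived from the number of leading zero filter
// coefficients.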
void WienerFilter_AVX2(
    const RestorationUnitInfo& restoration_info, const void* const source,
    const ptrdiff_t stride, const void* const top_border,
    const ptrdiff_t top_border_stride, const void* const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* const restoration_buffer, void* const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 32);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;

  // horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* const top = static_cast<const uint8_t*>(top_border);
  const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
  const __m128i c =
      LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]);
  // In order to keep the horizontal pass intermediate values within 16 bits we
  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
  __m128i c_horizontal =
      _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0));
  c_horizontal = _mm_packs_epi16(c_horizontal, c_horizontal);
  const __m256i coefficients_horizontal = _mm256_broadcastd_epi32(c_horizontal);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
                         top_border_stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
                         height_extra, &wiener_buffer_horizontal);
  }

  // vertical filtering.
  // Over-writes up to 15 values.
  const int16_t* const filter_vertical =
      restoration_info.wiener_info.filter[WienerInfo::kVertical];
  auto* dst = static_cast<uint8_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and
    // the bottom row of |source| is a duplicate of the row above it, we can
    // duplicate the top and bottom rows of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical + 1, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical + 2, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

constexpr int kSumOffset = 24;

// The SIMD loads over-read (the number of bytes in a SIMD register) -
// (width % 16) - 2 * padding pixels, where padding is 3 for Pass 1 and 2 for
// Pass 2. A SIMD register holds 16 bytes for SSE4.1 and 32 bytes for AVX2.
constexpr int kOverreadInBytesPass1_128 = 10;
constexpr int kOverreadInBytesPass2_128 = 12;
constexpr int kOverreadInBytesPass1_256 = kOverreadInBytesPass1_128 + 16;
constexpr int kOverreadInBytesPass2_256 = kOverreadInBytesPass2_128 + 16;

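// Load/store helpers for the SGR intermediate sum buffers. The *Msan variants
// are passed the number of bytes read past the valid |border| so that builds
// with MemorySanitizer can account for the intentional over-read.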
inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
}

inline void LoadAligned32x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
}

inline void LoadAligned32x2U16Msan(const uint16_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
  dst[2] = LoadAligned16(src[2] + x);
}

inline void LoadAligned32x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
  dst[2] = LoadAligned32(src[2] + x);
}

inline void LoadAligned32x3U16Msan(const uint16_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
  dst[2] = LoadAligned32Msan(src[2] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) {
  dst[0] = LoadAligned16(src + 0);
  dst[1] = LoadAligned16(src + 4);
}

inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32Msan(const uint32_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
}

inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
  LoadAligned32U32(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
  LoadAligned64(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32Msan(const uint32_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
  LoadAligned64Msan(src[2] + x, sizeof(**src) * (x + 16 - border), dst[2]);
}

inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) {
  StoreAligned16(dst + 0, src[0]);
  StoreAligned16(dst + 4, src[1]);
}

// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following
// functions. Some compilers may generate super inefficient code and the whole
// decoder could be 15% slower.

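// The Vaddl*/Vaddw*/Vmull*/Vrshr* helpers mirror the Arm NEON intrinsics of
// the same names: widening adds, widening accumulates, widening multiplies and
// rounding right shifts. The widening variants unpack against zero instead of
// using the cvtepu intrinsics, as noted above.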
inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
  return _mm_add_epi16(s0, s1);
}

inline __m256i VaddlLo8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m256i VaddlHi8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m128i VaddlLo16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(s0, s1);
}

inline __m256i VaddlLo16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(s0, s1);
}

inline __m128i VaddlHi16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(s0, s1);
}

inline __m256i VaddlHi16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(s0, s1);
}

inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
  return _mm_add_epi16(src0, s1);
}

inline __m256i VaddwLo8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m256i VaddwHi8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m128i VaddwLo16(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(src0, s1);
}

inline __m256i VaddwLo16(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(src0, s1);
}

inline __m128i VaddwHi16(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(src0, s1);
}

inline __m256i VaddwHi16(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(src0, s1);
}

inline __m256i VmullNLo8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m256i VmullNHi8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m128i VmullLo16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullLo16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m128i VmullHi16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullHi16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m256i VrshrS32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srai_epi32(sum, src1);
}

inline __m128i VrshrU32(const __m128i src0, const int src1) {
  const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1)));
  return _mm_srli_epi32(sum, src1);
}

inline __m256i VrshrU32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srli_epi32(sum, src1);
}

inline __m128i SquareLo8(const __m128i src) {
  const __m128i s = _mm_unpacklo_epi8(src, _mm_setzero_si128());
  return _mm_mullo_epi16(s, s);
}

inline __m256i SquareLo8(const __m256i src) {
  const __m256i s = _mm256_unpacklo_epi8(src, _mm256_setzero_si256());
  return _mm256_mullo_epi16(s, s);
}

inline __m128i SquareHi8(const __m128i src) {
  const __m128i s = _mm_unpackhi_epi8(src, _mm_setzero_si128());
  return _mm_mullo_epi16(s, s);
}

inline __m256i SquareHi8(const __m256i src) {
  const __m256i s = _mm256_unpackhi_epi8(src, _mm256_setzero_si256());
  return _mm256_mullo_epi16(s, s);
}

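// Prepare3*/Prepare5* produce the byte- or word-shifted copies of a source
// vector needed for 3-tap and 5-tap box sums; the Sum3*/Sum5* helpers below
// add those copies together at 16-bit and 32-bit widths.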
inline void Prepare3Lo8(const __m128i src, __m128i dst[3]) {
  dst[0] = src;
  dst[1] = _mm_srli_si128(src, 1);
  dst[2] = _mm_srli_si128(src, 2);
}

inline void Prepare3_8(const __m256i src[2], __m256i dst[3]) {
  dst[0] = _mm256_alignr_epi8(src[1], src[0], 0);
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 1);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 2);
}

inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm_alignr_epi8(src[1], src[0], 2);
  dst[2] = _mm_alignr_epi8(src[1], src[0], 4);
}

inline void Prepare3_16(const __m256i src[2], __m256i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 2);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 4);
}

inline void Prepare5Lo8(const __m128i src, __m128i dst[5]) {
  dst[0] = src;
  dst[1] = _mm_srli_si128(src, 1);
  dst[2] = _mm_srli_si128(src, 2);
  dst[3] = _mm_srli_si128(src, 3);
  dst[4] = _mm_srli_si128(src, 4);
}

inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) {
  Prepare3_16(src, dst);
  dst[3] = _mm_alignr_epi8(src[1], src[0], 6);
  dst[4] = _mm_alignr_epi8(src[1], src[0], 8);
}

inline void Prepare5_16(const __m256i src[2], __m256i dst[5]) {
  Prepare3_16(src, dst);
  dst[3] = _mm256_alignr_epi8(src[1], src[0], 6);
  dst[4] = _mm256_alignr_epi8(src[1], src[0], 8);
}

inline __m128i Sum3_16(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi16(src0, src1);
  return _mm_add_epi16(sum, src2);
}

inline __m256i Sum3_16(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi16(src0, src1);
  return _mm256_add_epi16(sum, src2);
}

inline __m128i Sum3_16(const __m128i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m256i Sum3_16(const __m256i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m128i Sum3_32(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi32(src0, src1);
  return _mm_add_epi32(sum, src2);
}

inline __m256i Sum3_32(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi32(src0, src1);
  return _mm256_add_epi32(sum, src2);
}

inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline void Sum3_32(const __m256i src[3][2], __m256i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline __m128i Sum3WLo16(const __m128i src[3]) {
  const __m128i sum = VaddlLo8(src[0], src[1]);
  return VaddwLo8(sum, src[2]);
}

inline __m256i Sum3WLo16(const __m256i src[3]) {
  const __m256i sum = VaddlLo8(src[0], src[1]);
  return VaddwLo8(sum, src[2]);
}

inline __m256i Sum3WHi16(const __m256i src[3]) {
  const __m256i sum = VaddlHi8(src[0], src[1]);
  return VaddwHi8(sum, src[2]);
}

inline __m128i Sum3WLo32(const __m128i src[3]) {
  const __m128i sum = VaddlLo16(src[0], src[1]);
  return VaddwLo16(sum, src[2]);
}

inline __m256i Sum3WLo32(const __m256i src[3]) {
  const __m256i sum = VaddlLo16(src[0], src[1]);
  return VaddwLo16(sum, src[2]);
}

inline __m128i Sum3WHi32(const __m128i src[3]) {
  const __m128i sum = VaddlHi16(src[0], src[1]);
  return VaddwHi16(sum, src[2]);
}

inline __m256i Sum3WHi32(const __m256i src[3]) {
  const __m256i sum = VaddlHi16(src[0], src[1]);
  return VaddwHi16(sum, src[2]);
}

inline __m128i Sum5_16(const __m128i src[5]) {
  const __m128i sum01 = _mm_add_epi16(src[0], src[1]);
  const __m128i sum23 = _mm_add_epi16(src[2], src[3]);
  const __m128i sum = _mm_add_epi16(sum01, sum23);
  return _mm_add_epi16(sum, src[4]);
}

inline __m256i Sum5_16(const __m256i src[5]) {
  const __m256i sum01 = _mm256_add_epi16(src[0], src[1]);
  const __m256i sum23 = _mm256_add_epi16(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return _mm256_add_epi16(sum, src[4]);
}

inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1,
                       const __m128i* const src2, const __m128i* const src3,
                       const __m128i* const src4) {
  const __m128i sum01 = _mm_add_epi32(*src0, *src1);
  const __m128i sum23 = _mm_add_epi32(*src2, *src3);
  const __m128i sum = _mm_add_epi32(sum01, sum23);
  return _mm_add_epi32(sum, *src4);
}

inline __m256i Sum5_32(const __m256i* const src0, const __m256i* const src1,
                       const __m256i* const src2, const __m256i* const src3,
                       const __m256i* const src4) {
  const __m256i sum01 = _mm256_add_epi32(*src0, *src1);
  const __m256i sum23 = _mm256_add_epi32(*src2, *src3);
  const __m256i sum = _mm256_add_epi32(sum01, sum23);
  return _mm256_add_epi32(sum, *src4);
}

inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline void Sum5_32(const __m256i src[5][2], __m256i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline __m128i Sum5WLo16(const __m128i src[5]) {
  const __m128i sum01 = VaddlLo8(src[0], src[1]);
  const __m128i sum23 = VaddlLo8(src[2], src[3]);
  const __m128i sum = _mm_add_epi16(sum01, sum23);
  return VaddwLo8(sum, src[4]);
}

inline __m256i Sum5WLo16(const __m256i src[5]) {
  const __m256i sum01 = VaddlLo8(src[0], src[1]);
  const __m256i sum23 = VaddlLo8(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return VaddwLo8(sum, src[4]);
}

inline __m256i Sum5WHi16(const __m256i src[5]) {
  const __m256i sum01 = VaddlHi8(src[0], src[1]);
  const __m256i sum23 = VaddlHi8(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return VaddwHi8(sum, src[4]);
}

inline __m128i Sum3Horizontal(const __m128i src) {
  __m128i s[3];
  Prepare3Lo8(src, s);
  return Sum3WLo16(s);
}

inline void Sum3Horizontal(const uint8_t* const src,
                           const ptrdiff_t over_read_in_bytes, __m256i dst[2]) {
  __m256i s[3];
  s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
  s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
  s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
  dst[0] = Sum3WLo16(s);
  dst[1] = Sum3WHi16(s);
}

inline void Sum3WHorizontal(const __m128i src[2], __m128i dst[2]) {
  __m128i s[3];
  Prepare3_16(src, s);
  dst[0] = Sum3WLo32(s);
  dst[1] = Sum3WHi32(s);
}

inline void Sum3WHorizontal(const __m256i src[2], __m256i dst[2]) {
  __m256i s[3];
  Prepare3_16(src, s);
  dst[0] = Sum3WLo32(s);
  dst[1] = Sum3WHi32(s);
}

inline __m128i Sum5Horizontal(const __m128i src) {
  __m128i s[5];
  Prepare5Lo8(src, s);
  return Sum5WLo16(s);
}

inline void Sum5Horizontal(const uint8_t* const src,
                           const ptrdiff_t over_read_in_bytes,
                           __m256i* const dst0, __m256i* const dst1) {
  __m256i s[5];
  s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
  s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
  s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
  s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3);
  s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4);
  *dst0 = Sum5WLo16(s);
  *dst1 = Sum5WHi16(s);
}

inline void Sum5WHorizontal(const __m128i src[2], __m128i dst[2]) {
  __m128i s[5];
  Prepare5_16(src, s);
  const __m128i sum01_lo = VaddlLo16(s[0], s[1]);
  const __m128i sum23_lo = VaddlLo16(s[2], s[3]);
  const __m128i sum0123_lo = _mm_add_epi32(sum01_lo, sum23_lo);
  dst[0] = VaddwLo16(sum0123_lo, s[4]);
  const __m128i sum01_hi = VaddlHi16(s[0], s[1]);
  const __m128i sum23_hi = VaddlHi16(s[2], s[3]);
  const __m128i sum0123_hi = _mm_add_epi32(sum01_hi, sum23_hi);
  dst[1] = VaddwHi16(sum0123_hi, s[4]);
}

inline void Sum5WHorizontal(const __m256i src[2], __m256i dst[2]) {
  __m256i s[5];
  Prepare5_16(src, s);
  const __m256i sum01_lo = VaddlLo16(s[0], s[1]);
  const __m256i sum23_lo = VaddlLo16(s[2], s[3]);
  const __m256i sum0123_lo = _mm256_add_epi32(sum01_lo, sum23_lo);
  dst[0] = VaddwLo16(sum0123_lo, s[4]);
  const __m256i sum01_hi = VaddlHi16(s[0], s[1]);
  const __m256i sum23_hi = VaddlHi16(s[2], s[3]);
  const __m256i sum0123_hi = _mm256_add_epi32(sum01_hi, sum23_hi);
  dst[1] = VaddwHi16(sum0123_hi, s[4]);
}

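// SumHorizontal* compute the 3-tap and 5-tap horizontal box sums in a single
// pass, reusing the 3-tap sum of the middle elements as a partial result for
// the 5-tap sum.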
SumHorizontalLo(const __m128i src[5],__m128i * const row_sq3,__m128i * const row_sq5)1103 void SumHorizontalLo(const __m128i src[5], __m128i* const row_sq3,
1104                      __m128i* const row_sq5) {
1105   const __m128i sum04 = VaddlLo16(src[0], src[4]);
1106   *row_sq3 = Sum3WLo32(src + 1);
1107   *row_sq5 = _mm_add_epi32(sum04, *row_sq3);
1108 }
1109 
SumHorizontalLo(const __m256i src[5],__m256i * const row_sq3,__m256i * const row_sq5)1110 void SumHorizontalLo(const __m256i src[5], __m256i* const row_sq3,
1111                      __m256i* const row_sq5) {
1112   const __m256i sum04 = VaddlLo16(src[0], src[4]);
1113   *row_sq3 = Sum3WLo32(src + 1);
1114   *row_sq5 = _mm256_add_epi32(sum04, *row_sq3);
1115 }
1116 
SumHorizontalHi(const __m128i src[5],__m128i * const row_sq3,__m128i * const row_sq5)1117 void SumHorizontalHi(const __m128i src[5], __m128i* const row_sq3,
1118                      __m128i* const row_sq5) {
1119   const __m128i sum04 = VaddlHi16(src[0], src[4]);
1120   *row_sq3 = Sum3WHi32(src + 1);
1121   *row_sq5 = _mm_add_epi32(sum04, *row_sq3);
1122 }
1123 
SumHorizontalHi(const __m256i src[5],__m256i * const row_sq3,__m256i * const row_sq5)1124 void SumHorizontalHi(const __m256i src[5], __m256i* const row_sq3,
1125                      __m256i* const row_sq5) {
1126   const __m256i sum04 = VaddlHi16(src[0], src[4]);
1127   *row_sq3 = Sum3WHi32(src + 1);
1128   *row_sq5 = _mm256_add_epi32(sum04, *row_sq3);
1129 }
1130 
SumHorizontalLo(const __m128i src,__m128i * const row3,__m128i * const row5)1131 void SumHorizontalLo(const __m128i src, __m128i* const row3,
1132                      __m128i* const row5) {
1133   __m128i s[5];
1134   Prepare5Lo8(src, s);
1135   const __m128i sum04 = VaddlLo8(s[0], s[4]);
1136   *row3 = Sum3WLo16(s + 1);
1137   *row5 = _mm_add_epi16(sum04, *row3);
1138 }
1139 
SumHorizontal(const uint8_t * const src,const ptrdiff_t over_read_in_bytes,__m256i * const row3_0,__m256i * const row3_1,__m256i * const row5_0,__m256i * const row5_1)1140 inline void SumHorizontal(const uint8_t* const src,
1141                           const ptrdiff_t over_read_in_bytes,
1142                           __m256i* const row3_0, __m256i* const row3_1,
1143                           __m256i* const row5_0, __m256i* const row5_1) {
1144   __m256i s[5];
1145   s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
1146   s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
1147   s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
1148   s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3);
1149   s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4);
1150   const __m256i sum04_lo = VaddlLo8(s[0], s[4]);
1151   const __m256i sum04_hi = VaddlHi8(s[0], s[4]);
1152   *row3_0 = Sum3WLo16(s + 1);
1153   *row3_1 = Sum3WHi16(s + 1);
1154   *row5_0 = _mm256_add_epi16(sum04_lo, *row3_0);
1155   *row5_1 = _mm256_add_epi16(sum04_hi, *row3_1);
1156 }
1157 
SumHorizontal(const __m128i src[2],__m128i * const row_sq3_0,__m128i * const row_sq3_1,__m128i * const row_sq5_0,__m128i * const row_sq5_1)1158 inline void SumHorizontal(const __m128i src[2], __m128i* const row_sq3_0,
1159                           __m128i* const row_sq3_1, __m128i* const row_sq5_0,
1160                           __m128i* const row_sq5_1) {
1161   __m128i s[5];
1162   Prepare5_16(src, s);
1163   SumHorizontalLo(s, row_sq3_0, row_sq5_0);
1164   SumHorizontalHi(s, row_sq3_1, row_sq5_1);
1165 }
1166 
SumHorizontal(const __m256i src[2],__m256i * const row_sq3_0,__m256i * const row_sq3_1,__m256i * const row_sq5_0,__m256i * const row_sq5_1)1167 inline void SumHorizontal(const __m256i src[2], __m256i* const row_sq3_0,
1168                           __m256i* const row_sq3_1, __m256i* const row_sq5_0,
1169                           __m256i* const row_sq5_1) {
1170   __m256i s[5];
1171   Prepare5_16(src, s);
1172   SumHorizontalLo(s, row_sq3_0, row_sq5_0);
1173   SumHorizontalHi(s, row_sq3_1, row_sq5_1);
1174 }
1175 
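// Sum343*() apply weights 3, 4, 3 to three adjacent values: the plain 3-tap
// sum is tripled and the middle value is added once more, since
// 3 * (a + b + c) + b == 3 * a + 4 * b + 3 * c.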
Sum343Lo(const __m256i ma3[3])1176 inline __m256i Sum343Lo(const __m256i ma3[3]) {
1177   const __m256i sum = Sum3WLo16(ma3);
1178   const __m256i sum3 = Sum3_16(sum, sum, sum);
1179   return VaddwLo8(sum3, ma3[1]);
1180 }
1181 
Sum343Hi(const __m256i ma3[3])1182 inline __m256i Sum343Hi(const __m256i ma3[3]) {
1183   const __m256i sum = Sum3WHi16(ma3);
1184   const __m256i sum3 = Sum3_16(sum, sum, sum);
1185   return VaddwHi8(sum3, ma3[1]);
1186 }
1187 
Sum343WLo(const __m256i src[3])1188 inline __m256i Sum343WLo(const __m256i src[3]) {
1189   const __m256i sum = Sum3WLo32(src);
1190   const __m256i sum3 = Sum3_32(sum, sum, sum);
1191   return VaddwLo16(sum3, src[1]);
1192 }
1193 
Sum343WHi(const __m256i src[3])1194 inline __m256i Sum343WHi(const __m256i src[3]) {
1195   const __m256i sum = Sum3WHi32(src);
1196   const __m256i sum3 = Sum3_32(sum, sum, sum);
1197   return VaddwHi16(sum3, src[1]);
1198 }
1199 
Sum343W(const __m256i src[2],__m256i dst[2])1200 inline void Sum343W(const __m256i src[2], __m256i dst[2]) {
1201   __m256i s[3];
1202   Prepare3_16(src, s);
1203   dst[0] = Sum343WLo(s);
1204   dst[1] = Sum343WHi(s);
1205 }
1206 
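// Sum565*() apply weights 5, 6, 5: the 3-tap sum is multiplied by 5 (a shift
// by 2 plus an add) and the middle value is added once more, since
// 5 * (a + b + c) + b == 5 * a + 6 * b + 5 * c.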
Sum565Lo(const __m256i src[3])1207 inline __m256i Sum565Lo(const __m256i src[3]) {
1208   const __m256i sum = Sum3WLo16(src);
1209   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1210   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1211   return VaddwLo8(sum5, src[1]);
1212 }
1213 
Sum565Hi(const __m256i src[3])1214 inline __m256i Sum565Hi(const __m256i src[3]) {
1215   const __m256i sum = Sum3WHi16(src);
1216   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1217   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1218   return VaddwHi8(sum5, src[1]);
1219 }
1220 
Sum565WLo(const __m256i src[3])1221 inline __m256i Sum565WLo(const __m256i src[3]) {
1222   const __m256i sum = Sum3WLo32(src);
1223   const __m256i sum4 = _mm256_slli_epi32(sum, 2);
1224   const __m256i sum5 = _mm256_add_epi32(sum4, sum);
1225   return VaddwLo16(sum5, src[1]);
1226 }
1227 
Sum565WHi(const __m256i src[3])1228 inline __m256i Sum565WHi(const __m256i src[3]) {
1229   const __m256i sum = Sum3WHi32(src);
1230   const __m256i sum4 = _mm256_slli_epi32(sum, 2);
1231   const __m256i sum5 = _mm256_add_epi32(sum4, sum);
1232   return VaddwHi16(sum5, src[1]);
1233 }
1234 
Sum565W(const __m256i src[2],__m256i dst[2])1235 inline void Sum565W(const __m256i src[2], __m256i dst[2]) {
1236   __m256i s[3];
1237   Prepare3_16(src, s);
1238   dst[0] = Sum565WLo(s);
1239   dst[1] = Sum565WHi(s);
1240 }
1241 
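// Computes horizontal sums for two rows. The leading 8 columns are handled
// with 128-bit vectors; the main loop then produces 32 columns per iteration
// with 256-bit vectors, storing the width-3 and width-5 sums and squared sums.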
BoxSum(const uint8_t * src,const ptrdiff_t src_stride,const ptrdiff_t width,const ptrdiff_t sum_stride,const ptrdiff_t sum_width,uint16_t * sum3,uint16_t * sum5,uint32_t * square_sum3,uint32_t * square_sum5)1242 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
1243                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1244                    const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
1245                    uint32_t* square_sum3, uint32_t* square_sum5) {
1246   int y = 2;
1247   do {
1248     const __m128i s0 =
1249         LoadUnaligned16Msan(src, kOverreadInBytesPass1_128 - width);
1250     __m128i sq_128[2], s3, s5, sq3[2], sq5[2];
1251     __m256i sq[3];
1252     sq_128[0] = SquareLo8(s0);
1253     sq_128[1] = SquareHi8(s0);
1254     SumHorizontalLo(s0, &s3, &s5);
1255     StoreAligned16(sum3, s3);
1256     StoreAligned16(sum5, s5);
1257     SumHorizontal(sq_128, &sq3[0], &sq3[1], &sq5[0], &sq5[1]);
1258     StoreAligned32U32(square_sum3, sq3);
1259     StoreAligned32U32(square_sum5, sq5);
1260     src += 8;
1261     sum3 += 8;
1262     sum5 += 8;
1263     square_sum3 += 8;
1264     square_sum5 += 8;
1265     sq[0] = SetrM128i(sq_128[1], sq_128[1]);
1266     ptrdiff_t x = sum_width;
1267     do {
1268       __m256i row3[2], row5[2], row_sq3[2], row_sq5[2];
1269       const __m256i s = LoadUnaligned32Msan(
1270           src + 8, sum_width - x + 16 + kOverreadInBytesPass1_256 - width);
1271       sq[1] = SquareLo8(s);
1272       sq[2] = SquareHi8(s);
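      // Selector 0x21 places the high 128-bit lane of the first operand in
      // the low lane of the result and the low lane of the second operand in
      // the high lane, carrying the previous squared samples forward so the
      // horizontal sums can window across the 32-byte load boundary.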
1273       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1274       SumHorizontal(src, sum_width - x + 8 + kOverreadInBytesPass1_256 - width,
1275                     &row3[0], &row3[1], &row5[0], &row5[1]);
1276       StoreAligned64(sum3, row3);
1277       StoreAligned64(sum5, row5);
1278       SumHorizontal(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]);
1279       StoreAligned64(square_sum3 + 0, row_sq3);
1280       StoreAligned64(square_sum5 + 0, row_sq5);
1281       SumHorizontal(sq + 1, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]);
1282       StoreAligned64(square_sum3 + 16, row_sq3);
1283       StoreAligned64(square_sum5 + 16, row_sq5);
1284       sq[0] = sq[2];
1285       src += 32;
1286       sum3 += 32;
1287       sum5 += 32;
1288       square_sum3 += 32;
1289       square_sum5 += 32;
1290       x -= 32;
1291     } while (x != 0);
1292     src += src_stride - sum_width - 8;
1293     sum3 += sum_stride - sum_width - 8;
1294     sum5 += sum_stride - sum_width - 8;
1295     square_sum3 += sum_stride - sum_width - 8;
1296     square_sum5 += sum_stride - sum_width - 8;
1297   } while (--y != 0);
1298 }
1299 
1300 template <int size>
BoxSum(const uint8_t * src,const ptrdiff_t src_stride,const ptrdiff_t width,const ptrdiff_t sum_stride,const ptrdiff_t sum_width,uint16_t * sums,uint32_t * square_sums)1301 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
1302                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1303                    const ptrdiff_t sum_width, uint16_t* sums,
1304                    uint32_t* square_sums) {
1305   static_assert(size == 3 || size == 5, "");
1306   int kOverreadInBytes_128, kOverreadInBytes_256;
1307   if (size == 3) {
1308     kOverreadInBytes_128 = kOverreadInBytesPass2_128;
1309     kOverreadInBytes_256 = kOverreadInBytesPass2_256;
1310   } else {
1311     kOverreadInBytes_128 = kOverreadInBytesPass1_128;
1312     kOverreadInBytes_256 = kOverreadInBytesPass1_256;
1313   }
1314   int y = 2;
1315   do {
1316     const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytes_128 - width);
1317     __m128i ss, sq_128[2], sqs[2];
1318     __m256i sq[3];
1319     sq_128[0] = SquareLo8(s);
1320     sq_128[1] = SquareHi8(s);
1321     if (size == 3) {
1322       ss = Sum3Horizontal(s);
1323       Sum3WHorizontal(sq_128, sqs);
1324     } else {
1325       ss = Sum5Horizontal(s);
1326       Sum5WHorizontal(sq_128, sqs);
1327     }
1328     StoreAligned16(sums, ss);
1329     StoreAligned32U32(square_sums, sqs);
1330     src += 8;
1331     sums += 8;
1332     square_sums += 8;
1333     sq[0] = SetrM128i(sq_128[1], sq_128[1]);
1334     ptrdiff_t x = sum_width;
1335     do {
1336       __m256i row[2], row_sq[4];
1337       const __m256i s = LoadUnaligned32Msan(
1338           src + 8, sum_width - x + 16 + kOverreadInBytes_256 - width);
1339       sq[1] = SquareLo8(s);
1340       sq[2] = SquareHi8(s);
1341       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1342       if (size == 3) {
1343         Sum3Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width,
1344                        row);
1345         Sum3WHorizontal(sq + 0, row_sq + 0);
1346         Sum3WHorizontal(sq + 1, row_sq + 2);
1347       } else {
1348         Sum5Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width,
1349                        &row[0], &row[1]);
1350         Sum5WHorizontal(sq + 0, row_sq + 0);
1351         Sum5WHorizontal(sq + 1, row_sq + 2);
1352       }
1353       StoreAligned64(sums, row);
1354       StoreAligned64(square_sums + 0, row_sq + 0);
1355       StoreAligned64(square_sums + 16, row_sq + 2);
1356       sq[0] = sq[2];
1357       src += 32;
1358       sums += 32;
1359       square_sums += 32;
1360       x -= 32;
1361     } while (x != 0);
1362     src += src_stride - sum_width - 8;
1363     sums += sum_stride - sum_width - 8;
1364     square_sums += sum_stride - sum_width - 8;
1365   } while (--y != 0);
1366 }
1367 
1368 template <int n>
CalculateMa(const __m128i sum,const __m128i sum_sq,const uint32_t scale)1369 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq,
1370                            const uint32_t scale) {
1371   static_assert(n == 9 || n == 25, "");
1372   // a = |sum_sq|
1373   // d = |sum|
1374   // p = (a * n < d * d) ? 0 : a * n - d * d;
1375   const __m128i dxd = _mm_madd_epi16(sum, sum);
1376   // _mm_mullo_epi32() has high latency. Using shifts and additions instead.
1377   // Some compilers could do this for us but we make this explicit.
1378   // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n));
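  // n == 9:  9 * x == x + (x << 3).
  // n == 25: 25 * x == x + (x << 3) + (x << 4).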
1379   __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3));
1380   if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4));
1381   const __m128i sub = _mm_sub_epi32(axn, dxd);
1382   const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128());
1383   const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale));
1384   return VrshrU32(pxs, kSgrProjScaleBits);
1385 }
1386 
1387 template <int n>
CalculateMa(const __m128i sum,const __m128i sum_sq[2],const uint32_t scale)1388 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2],
1389                            const uint32_t scale) {
1390   static_assert(n == 9 || n == 25, "");
1391   const __m128i sum_lo = _mm_unpacklo_epi16(sum, _mm_setzero_si128());
1392   const __m128i sum_hi = _mm_unpackhi_epi16(sum, _mm_setzero_si128());
1393   const __m128i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale);
1394   const __m128i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale);
1395   return _mm_packus_epi32(z0, z1);
1396 }
1397 
1398 template <int n>
CalculateMa(const __m256i sum,const __m256i sum_sq,const uint32_t scale)1399 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq,
1400                            const uint32_t scale) {
1401   static_assert(n == 9 || n == 25, "");
1402   // a = |sum_sq|
1403   // d = |sum|
1404   // p = (a * n < d * d) ? 0 : a * n - d * d;
1405   const __m256i dxd = _mm256_madd_epi16(sum, sum);
1406   // _mm256_mullo_epi32() has high latency. Using shifts and additions instead.
1407   // Some compilers could do this for us but we make this explicit.
1408   // return _mm256_mullo_epi32(sum_sq, _mm256_set1_epi32(n));
1409   __m256i axn = _mm256_add_epi32(sum_sq, _mm256_slli_epi32(sum_sq, 3));
1410   if (n == 25) axn = _mm256_add_epi32(axn, _mm256_slli_epi32(sum_sq, 4));
1411   const __m256i sub = _mm256_sub_epi32(axn, dxd);
1412   const __m256i p = _mm256_max_epi32(sub, _mm256_setzero_si256());
1413   const __m256i pxs = _mm256_mullo_epi32(p, _mm256_set1_epi32(scale));
1414   return VrshrU32(pxs, kSgrProjScaleBits);
1415 }
1416 
1417 template <int n>
CalculateMa(const __m256i sum,const __m256i sum_sq[2],const uint32_t scale)1418 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2],
1419                            const uint32_t scale) {
1420   static_assert(n == 9 || n == 25, "");
1421   const __m256i sum_lo = _mm256_unpacklo_epi16(sum, _mm256_setzero_si256());
1422   const __m256i sum_hi = _mm256_unpackhi_epi16(sum, _mm256_setzero_si256());
1423   const __m256i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale);
1424   const __m256i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale);
1425   return _mm256_packus_epi32(z0, z1);
1426 }
1427 
CalculateB5(const __m128i sum,const __m128i ma)1428 inline __m128i CalculateB5(const __m128i sum, const __m128i ma) {
1429   // one_over_n == 164.
1430   constexpr uint32_t one_over_n =
1431       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1432   // one_over_n_quarter == 41.
1433   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1434   static_assert(one_over_n == one_over_n_quarter << 2, "");
1435   // |ma| is in range [0, 255].
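  // _mm_maddubs_epi16() treats its second operand as signed bytes, so the full
  // |one_over_n| (164) cannot be used directly. A quarter of it (41) is
  // multiplied here, and the missing factor of 4 is restored by shifting right
  // by (kSgrProjReciprocalBits - 2) below.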
1436   const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter));
1437   const __m128i m0 = VmullLo16(m, sum);
1438   const __m128i m1 = VmullHi16(m, sum);
1439   const __m128i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1440   const __m128i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1441   return _mm_packus_epi32(b_lo, b_hi);
1442 }
1443 
CalculateB5(const __m256i sum,const __m256i ma)1444 inline __m256i CalculateB5(const __m256i sum, const __m256i ma) {
1445   // one_over_n == 164.
1446   constexpr uint32_t one_over_n =
1447       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1448   // one_over_n_quarter == 41.
1449   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1450   static_assert(one_over_n == one_over_n_quarter << 2, "");
1451   // |ma| is in range [0, 255].
1452   const __m256i m =
1453       _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter));
1454   const __m256i m0 = VmullLo16(m, sum);
1455   const __m256i m1 = VmullHi16(m, sum);
1456   const __m256i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1457   const __m256i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1458   return _mm256_packus_epi32(b_lo, b_hi);
1459 }
1460 
CalculateB3(const __m128i sum,const __m128i ma)1461 inline __m128i CalculateB3(const __m128i sum, const __m128i ma) {
1462   // one_over_n == 455.
1463   constexpr uint32_t one_over_n =
1464       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1465   const __m128i m0 = VmullLo16(ma, sum);
1466   const __m128i m1 = VmullHi16(ma, sum);
1467   const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n));
1468   const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n));
1469   const __m128i b_lo = VrshrU32(m2, kSgrProjReciprocalBits);
1470   const __m128i b_hi = VrshrU32(m3, kSgrProjReciprocalBits);
1471   return _mm_packus_epi32(b_lo, b_hi);
1472 }
1473 
CalculateB3(const __m256i sum,const __m256i ma)1474 inline __m256i CalculateB3(const __m256i sum, const __m256i ma) {
1475   // one_over_n == 455.
1476   constexpr uint32_t one_over_n =
1477       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1478   const __m256i m0 = VmullLo16(ma, sum);
1479   const __m256i m1 = VmullHi16(ma, sum);
1480   const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n));
1481   const __m256i m3 = _mm256_mullo_epi32(m1, _mm256_set1_epi32(one_over_n));
1482   const __m256i b_lo = VrshrU32(m2, kSgrProjReciprocalBits);
1483   const __m256i b_hi = VrshrU32(m3, kSgrProjReciprocalBits);
1484   return _mm256_packus_epi32(b_lo, b_hi);
1485 }
1486 
CalculateSumAndIndex5(const __m128i s5[5],const __m128i sq5[5][2],const uint32_t scale,__m128i * const sum,__m128i * const index)1487 inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2],
1488                                   const uint32_t scale, __m128i* const sum,
1489                                   __m128i* const index) {
1490   __m128i sum_sq[2];
1491   *sum = Sum5_16(s5);
1492   Sum5_32(sq5, sum_sq);
1493   *index = CalculateMa<25>(*sum, sum_sq, scale);
1494 }
1495 
CalculateSumAndIndex5(const __m256i s5[5],const __m256i sq5[5][2],const uint32_t scale,__m256i * const sum,__m256i * const index)1496 inline void CalculateSumAndIndex5(const __m256i s5[5], const __m256i sq5[5][2],
1497                                   const uint32_t scale, __m256i* const sum,
1498                                   __m256i* const index) {
1499   __m256i sum_sq[2];
1500   *sum = Sum5_16(s5);
1501   Sum5_32(sq5, sum_sq);
1502   *index = CalculateMa<25>(*sum, sum_sq, scale);
1503 }
1504 
CalculateSumAndIndex3(const __m128i s3[3],const __m128i sq3[3][2],const uint32_t scale,__m128i * const sum,__m128i * const index)1505 inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2],
1506                                   const uint32_t scale, __m128i* const sum,
1507                                   __m128i* const index) {
1508   __m128i sum_sq[2];
1509   *sum = Sum3_16(s3);
1510   Sum3_32(sq3, sum_sq);
1511   *index = CalculateMa<9>(*sum, sum_sq, scale);
1512 }
1513 
CalculateSumAndIndex3(const __m256i s3[3],const __m256i sq3[3][2],const uint32_t scale,__m256i * const sum,__m256i * const index)1514 inline void CalculateSumAndIndex3(const __m256i s3[3], const __m256i sq3[3][2],
1515                                   const uint32_t scale, __m256i* const sum,
1516                                   __m256i* const index) {
1517   __m256i sum_sq[2];
1518   *sum = Sum3_16(s3);
1519   Sum3_32(sq3, sum_sq);
1520   *index = CalculateMa<9>(*sum, sum_sq, scale);
1521 }
1522 
1523 template <int n>
LookupIntermediate(const __m128i sum,const __m128i index,__m128i * const ma,__m128i * const b)1524 inline void LookupIntermediate(const __m128i sum, const __m128i index,
1525                                __m128i* const ma, __m128i* const b) {
1526   static_assert(n == 9 || n == 25, "");
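  // _mm_packus_epi16() saturates each 16-bit index to [0, 255], which also
  // clamps it to the valid range of kSgrMaLookup.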
1527   const __m128i idx = _mm_packus_epi16(index, index);
1528   // |temp| is not actually stored and loaded; the compiler keeps it in a
1529   // 64-bit general-purpose register. Faster than using _mm_extract_epi8().
1530   uint8_t temp[8];
1531   StoreLo8(temp, idx);
1532   *ma = _mm_cvtsi32_si128(kSgrMaLookup[temp[0]]);
1533   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], 1);
1534   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], 2);
1535   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], 3);
1536   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], 4);
1537   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], 5);
1538   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], 6);
1539   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], 7);
1540   // b = ma * b * one_over_n
1541   // |ma| = [0, 255]
1542   // |sum| is a box sum with radius 1 or 2.
1543   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1544   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1545   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1546   // When radius is 2 |n| is 25. |one_over_n| is 164.
1547   // When radius is 1 |n| is 9. |one_over_n| is 455.
1548   // |kSgrProjReciprocalBits| is 12.
1549   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1550   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1551   const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128());
1552   *b = (n == 9) ? CalculateB3(sum, maq) : CalculateB5(sum, maq);
1553 }
1554 
1555 // Repeat the first 48 elements in kSgrMaLookup with a period of 16.
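// Each 16-entry segment is duplicated because _mm256_shuffle_epi8() looks up
// within each 128-bit lane independently, so both lanes need the same table.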
1556 alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = {
1557     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1558     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1559     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1560     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1561     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5,
1562     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5};
1563 
1564 // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b
1565 // to get value 0 as the shuffle result. The most significant bit 1 comes
1566 // either from the comparison instruction, or from the sign bit of the index.
ShuffleIndex(const __m256i table,const __m256i index)1567 inline __m256i ShuffleIndex(const __m256i table, const __m256i index) {
1568   __m256i mask;
1569   mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15));
1570   mask = _mm256_or_si256(mask, index);
1571   return _mm256_shuffle_epi8(table, mask);
1572 }
1573 
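// _mm256_cmpgt_epi8() yields 0 or -1 per byte, so the add below subtracts 1
// from |value| wherever |index| exceeds the threshold.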
AdjustValue(const __m256i value,const __m256i index,const int threshold)1574 inline __m256i AdjustValue(const __m256i value, const __m256i index,
1575                            const int threshold) {
1576   const __m256i thresholds = _mm256_set1_epi8(threshold - 128);
1577   const __m256i offset = _mm256_cmpgt_epi8(index, thresholds);
1578   return _mm256_add_epi8(value, offset);
1579 }
1580 
1581 template <int n>
CalculateIntermediate(const __m256i sum[2],const __m256i index[2],__m256i ma[3],__m256i b[2])1582 inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2],
1583                                   __m256i ma[3], __m256i b[2]) {
1584   static_assert(n == 9 || n == 25, "");
1585   // Use table lookup to read elements whose indices are less than 48.
1586   const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32);
1587   const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32);
1588   const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32);
1589   const __m256i indices = _mm256_packus_epi16(index[0], index[1]);
1590   __m256i idx, mas;
1591   // Clip idx to 127 to apply signed comparison instructions.
1592   idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127));
1593   // Elements whose indices are larger than 47 get 0 from all of the shuffles below.
1594   // Get shuffle results for indices in range [0, 15].
1595   mas = ShuffleIndex(c0, idx);
1596   // Get shuffle results for indices in range [16, 31].
1597   // Subtract 16 to utilize the sign bit of the index.
1598   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1599   const __m256i res1 = ShuffleIndex(c1, idx);
1600   // Combine the shuffle results with an OR.
1601   mas = _mm256_or_si256(mas, res1);
1602   // Get shuffle results for indices in range [32, 47].
1603   // Subtract 16 to utilize the sign bit of the index.
1604   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1605   const __m256i res2 = ShuffleIndex(c2, idx);
1606   mas = _mm256_or_si256(mas, res2);
1607 
1608   // Elements whose indices are larger than 47 change value only rarely as
1609   // the index increases, so their values are computed with comparison and
1610   // arithmetic operations instead of table lookups.
1611   // Add -128 to apply signed comparison instructions.
1612   idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128));
1613   // Elements whose indices are larger than 47 (with value 0) are set to 5.
1614   mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5));
1615   mas = AdjustValue(mas, idx, 55);   // 55 is the last index whose value is 5.
1616   mas = AdjustValue(mas, idx, 72);   // 72 is the last index whose value is 4.
1617   mas = AdjustValue(mas, idx, 101);  // 101 is the last index whose value is 3.
1618   mas = AdjustValue(mas, idx, 169);  // 169 is the last index whose value is 2.
1619   mas = AdjustValue(mas, idx, 254);  // 254 is the last index whose value is 1.
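  // After the adjustments above, |mas| steps down from 5 to 0 as the index
  // passes each threshold, matching kSgrMaLookup for indices larger than 47.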
1620 
1621   ma[2] = _mm256_permute4x64_epi64(mas, 0x93);     // 32-39 8-15 16-23 24-31
1622   ma[0] = _mm256_blend_epi32(ma[0], ma[2], 0xfc);  //  0-7  8-15 16-23 24-31
1623   ma[1] = _mm256_permute2x128_si256(ma[0], ma[2], 0x21);
1624 
1625   // b = ma * b * one_over_n
1626   // |ma| = [0, 255]
1627   // |sum| is a box sum with radius 1 or 2.
1628   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1629   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1630   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1631   // When radius is 2 |n| is 25. |one_over_n| is 164.
1632   // When radius is 1 |n| is 9. |one_over_n| is 455.
1633   // |kSgrProjReciprocalBits| is 12.
1634   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1635   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1636   const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256());
1637   const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256());
1638   if (n == 9) {
1639     b[0] = CalculateB3(sum[0], maq0);
1640     b[1] = CalculateB3(sum[1], maq1);
1641   } else {
1642     b[0] = CalculateB5(sum[0], maq0);
1643     b[1] = CalculateB5(sum[1], maq1);
1644   }
1645 }
1646 
CalculateIntermediate5(const __m128i s5[5],const __m128i sq5[5][2],const uint32_t scale,__m128i * const ma,__m128i * const b)1647 inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2],
1648                                    const uint32_t scale, __m128i* const ma,
1649                                    __m128i* const b) {
1650   __m128i sum, index;
1651   CalculateSumAndIndex5(s5, sq5, scale, &sum, &index);
1652   LookupIntermediate<25>(sum, index, ma, b);
1653 }
1654 
CalculateIntermediate3(const __m128i s3[3],const __m128i sq3[3][2],const uint32_t scale,__m128i * const ma,__m128i * const b)1655 inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2],
1656                                    const uint32_t scale, __m128i* const ma,
1657                                    __m128i* const b) {
1658   __m128i sum, index;
1659   CalculateSumAndIndex3(s3, sq3, scale, &sum, &index);
1660   LookupIntermediate<9>(sum, index, ma, b);
1661 }
1662 
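// Store343_444() derives both weighted sums from the plain 3-tap sum: the
// "444" sum is 4 * (a + b + c), and the "343" sum is the "444" sum minus one
// (a + b + c) plus the middle value, i.e. 3 * a + 4 * b + 3 * c.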
Store343_444(const __m256i b3[2],const ptrdiff_t x,__m256i sum_b343[2],__m256i sum_b444[2],uint32_t * const b343,uint32_t * const b444)1663 inline void Store343_444(const __m256i b3[2], const ptrdiff_t x,
1664                          __m256i sum_b343[2], __m256i sum_b444[2],
1665                          uint32_t* const b343, uint32_t* const b444) {
1666   __m256i b[3], sum_b111[2];
1667   Prepare3_16(b3, b);
1668   sum_b111[0] = Sum3WLo32(b);
1669   sum_b111[1] = Sum3WHi32(b);
1670   sum_b444[0] = _mm256_slli_epi32(sum_b111[0], 2);
1671   sum_b444[1] = _mm256_slli_epi32(sum_b111[1], 2);
1672   StoreAligned64(b444 + x, sum_b444);
1673   sum_b343[0] = _mm256_sub_epi32(sum_b444[0], sum_b111[0]);
1674   sum_b343[1] = _mm256_sub_epi32(sum_b444[1], sum_b111[1]);
1675   sum_b343[0] = VaddwLo16(sum_b343[0], b[1]);
1676   sum_b343[1] = VaddwHi16(sum_b343[1], b[1]);
1677   StoreAligned64(b343 + x, sum_b343);
1678 }
1679 
Store343_444Lo(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,__m256i * const sum_ma343,__m256i * const sum_ma444,__m256i sum_b343[2],__m256i sum_b444[2],uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1680 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1681                            const ptrdiff_t x, __m256i* const sum_ma343,
1682                            __m256i* const sum_ma444, __m256i sum_b343[2],
1683                            __m256i sum_b444[2], uint16_t* const ma343,
1684                            uint16_t* const ma444, uint32_t* const b343,
1685                            uint32_t* const b444) {
1686   const __m256i sum_ma111 = Sum3WLo16(ma3);
1687   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1688   StoreAligned32(ma444 + x, *sum_ma444);
1689   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1690   *sum_ma343 = VaddwLo8(sum333, ma3[1]);
1691   StoreAligned32(ma343 + x, *sum_ma343);
1692   Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
1693 }
1694 
Store343_444Hi(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,__m256i * const sum_ma343,__m256i * const sum_ma444,__m256i sum_b343[2],__m256i sum_b444[2],uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1695 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1696                            const ptrdiff_t x, __m256i* const sum_ma343,
1697                            __m256i* const sum_ma444, __m256i sum_b343[2],
1698                            __m256i sum_b444[2], uint16_t* const ma343,
1699                            uint16_t* const ma444, uint32_t* const b343,
1700                            uint32_t* const b444) {
1701   const __m256i sum_ma111 = Sum3WHi16(ma3);
1702   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1703   StoreAligned32(ma444 + x, *sum_ma444);
1704   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1705   *sum_ma343 = VaddwHi8(sum333, ma3[1]);
1706   StoreAligned32(ma343 + x, *sum_ma343);
1707   Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
1708 }
1709 
Store343_444Lo(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,__m256i * const sum_ma343,__m256i sum_b343[2],uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1710 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1711                            const ptrdiff_t x, __m256i* const sum_ma343,
1712                            __m256i sum_b343[2], uint16_t* const ma343,
1713                            uint16_t* const ma444, uint32_t* const b343,
1714                            uint32_t* const b444) {
1715   __m256i sum_ma444, sum_b444[2];
1716   Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1717                  ma444, b343, b444);
1718 }
1719 
Store343_444Hi(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,__m256i * const sum_ma343,__m256i sum_b343[2],uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1720 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1721                            const ptrdiff_t x, __m256i* const sum_ma343,
1722                            __m256i sum_b343[2], uint16_t* const ma343,
1723                            uint16_t* const ma444, uint32_t* const b343,
1724                            uint32_t* const b444) {
1725   __m256i sum_ma444, sum_b444[2];
1726   Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1727                  ma444, b343, b444);
1728 }
1729 
Store343_444Lo(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1730 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1731                            const ptrdiff_t x, uint16_t* const ma343,
1732                            uint16_t* const ma444, uint32_t* const b343,
1733                            uint32_t* const b444) {
1734   __m256i sum_ma343, sum_b343[2];
1735   Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1736 }
1737 
Store343_444Hi(const __m256i ma3[3],const __m256i b3[2],const ptrdiff_t x,uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1738 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1739                            const ptrdiff_t x, uint16_t* const ma343,
1740                            uint16_t* const ma444, uint32_t* const b343,
1741                            uint32_t* const b444) {
1742   __m256i sum_ma343, sum_b343[2];
1743   Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1744 }
1745 
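// The *Lo() variants below process the leftmost 8 columns with 128-bit
// vectors; the main loops then continue with 256-bit vectors, 32 columns at
// a time.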
BoxFilterPreProcess5Lo(const __m128i s[2][3],const uint32_t scale,uint16_t * const sum5[5],uint32_t * const square_sum5[5],__m128i sq[2][2],__m128i * const ma,__m128i * const b)1746 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
1747     const __m128i s[2][3], const uint32_t scale, uint16_t* const sum5[5],
1748     uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i* const ma,
1749     __m128i* const b) {
1750   __m128i s5[2][5], sq5[5][2];
1751   sq[0][1] = SquareHi8(s[0][0]);
1752   sq[1][1] = SquareHi8(s[1][0]);
1753   s5[0][3] = Sum5Horizontal(s[0][0]);
1754   StoreAligned16(sum5[3], s5[0][3]);
1755   s5[0][4] = Sum5Horizontal(s[1][0]);
1756   StoreAligned16(sum5[4], s5[0][4]);
1757   Sum5WHorizontal(sq[0], sq5[3]);
1758   StoreAligned32U32(square_sum5[3], sq5[3]);
1759   Sum5WHorizontal(sq[1], sq5[4]);
1760   StoreAligned32U32(square_sum5[4], sq5[4]);
1761   LoadAligned16x3U16(sum5, 0, s5[0]);
1762   LoadAligned32x3U32(square_sum5, 0, sq5);
1763   CalculateIntermediate5(s5[0], sq5, scale, ma, b);
1764 }
1765 
BoxFilterPreProcess5(const uint8_t * const src0,const uint8_t * const src1,const ptrdiff_t over_read_in_bytes,const ptrdiff_t sum_width,const ptrdiff_t x,const uint32_t scale,uint16_t * const sum5[5],uint32_t * const square_sum5[5],__m256i sq[2][3],__m256i ma[3],__m256i b[3])1766 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
1767     const uint8_t* const src0, const uint8_t* const src1,
1768     const ptrdiff_t over_read_in_bytes, const ptrdiff_t sum_width,
1769     const ptrdiff_t x, const uint32_t scale, uint16_t* const sum5[5],
1770     uint32_t* const square_sum5[5], __m256i sq[2][3], __m256i ma[3],
1771     __m256i b[3]) {
1772   const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8);
1773   const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8);
1774   __m256i s5[2][5], sq5[5][2], sum[2], index[2];
1775   sq[0][1] = SquareLo8(s0);
1776   sq[0][2] = SquareHi8(s0);
1777   sq[1][1] = SquareLo8(s1);
1778   sq[1][2] = SquareHi8(s1);
1779   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
1780   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
1781   Sum5Horizontal(src0, over_read_in_bytes, &s5[0][3], &s5[1][3]);
1782   Sum5Horizontal(src1, over_read_in_bytes, &s5[0][4], &s5[1][4]);
1783   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
1784   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
1785   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
1786   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
1787   Sum5WHorizontal(sq[0], sq5[3]);
1788   StoreAligned64(square_sum5[3] + x, sq5[3]);
1789   Sum5WHorizontal(sq[1], sq5[4]);
1790   StoreAligned64(square_sum5[4] + x, sq5[4]);
1791   LoadAligned32x3U16(sum5, x, s5[0]);
1792   LoadAligned64x3U32(square_sum5, x, sq5);
1793   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1794 
1795   Sum5WHorizontal(sq[0] + 1, sq5[3]);
1796   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
1797   Sum5WHorizontal(sq[1] + 1, sq5[4]);
1798   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
1799   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1800   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1801   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1802   CalculateIntermediate<25>(sum, index, ma, b + 1);
1803   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1804 }
1805 
BoxFilterPreProcess5LastRowLo(const __m128i s,const uint32_t scale,const uint16_t * const sum5[5],const uint32_t * const square_sum5[5],__m128i sq[2],__m128i * const ma,__m128i * const b)1806 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
1807     const __m128i s, const uint32_t scale, const uint16_t* const sum5[5],
1808     const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma,
1809     __m128i* const b) {
1810   __m128i s5[5], sq5[5][2];
1811   sq[1] = SquareHi8(s);
1812   s5[3] = s5[4] = Sum5Horizontal(s);
1813   Sum5WHorizontal(sq, sq5[3]);
1814   sq5[4][0] = sq5[3][0];
1815   sq5[4][1] = sq5[3][1];
1816   LoadAligned16x3U16(sum5, 0, s5);
1817   LoadAligned32x3U32(square_sum5, 0, sq5);
1818   CalculateIntermediate5(s5, sq5, scale, ma, b);
1819 }
1820 
BoxFilterPreProcess5LastRow(const uint8_t * const src,const ptrdiff_t over_read_in_bytes,const ptrdiff_t sum_width,const ptrdiff_t x,const uint32_t scale,const uint16_t * const sum5[5],const uint32_t * const square_sum5[5],__m256i sq[3],__m256i ma[3],__m256i b[3])1821 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
1822     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
1823     const ptrdiff_t sum_width, const ptrdiff_t x, const uint32_t scale,
1824     const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
1825     __m256i sq[3], __m256i ma[3], __m256i b[3]) {
1826   const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
1827   __m256i s5[2][5], sq5[5][2], sum[2], index[2];
1828   sq[1] = SquareLo8(s);
1829   sq[2] = SquareHi8(s);
1830   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1831   Sum5Horizontal(src, over_read_in_bytes, &s5[0][3], &s5[1][3]);
1832   s5[0][4] = s5[0][3];
1833   s5[1][4] = s5[1][3];
1834   Sum5WHorizontal(sq, sq5[3]);
1835   sq5[4][0] = sq5[3][0];
1836   sq5[4][1] = sq5[3][1];
1837   LoadAligned32x3U16(sum5, x, s5[0]);
1838   LoadAligned64x3U32(square_sum5, x, sq5);
1839   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1840 
1841   Sum5WHorizontal(sq + 1, sq5[3]);
1842   sq5[4][0] = sq5[3][0];
1843   sq5[4][1] = sq5[3][1];
1844   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1845   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1846   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1847   CalculateIntermediate<25>(sum, index, ma, b + 1);
1848   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1849 }
1850 
BoxFilterPreProcess3Lo(const __m128i s,const uint32_t scale,uint16_t * const sum3[3],uint32_t * const square_sum3[3],__m128i sq[2],__m128i * const ma,__m128i * const b)1851 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
1852     const __m128i s, const uint32_t scale, uint16_t* const sum3[3],
1853     uint32_t* const square_sum3[3], __m128i sq[2], __m128i* const ma,
1854     __m128i* const b) {
1855   __m128i s3[3], sq3[3][2];
1856   sq[1] = SquareHi8(s);
1857   s3[2] = Sum3Horizontal(s);
1858   StoreAligned16(sum3[2], s3[2]);
1859   Sum3WHorizontal(sq, sq3[2]);
1860   StoreAligned32U32(square_sum3[2], sq3[2]);
1861   LoadAligned16x2U16(sum3, 0, s3);
1862   LoadAligned32x2U32(square_sum3, 0, sq3);
1863   CalculateIntermediate3(s3, sq3, scale, ma, b);
1864 }
1865 
BoxFilterPreProcess3(const uint8_t * const src,const ptrdiff_t over_read_in_bytes,const ptrdiff_t x,const ptrdiff_t sum_width,const uint32_t scale,uint16_t * const sum3[3],uint32_t * const square_sum3[3],__m256i sq[3],__m256i ma[3],__m256i b[3])1866 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
1867     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
1868     const ptrdiff_t x, const ptrdiff_t sum_width, const uint32_t scale,
1869     uint16_t* const sum3[3], uint32_t* const square_sum3[3], __m256i sq[3],
1870     __m256i ma[3], __m256i b[3]) {
1871   const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
1872   __m256i s3[4], sq3[3][2], sum[2], index[2];
1873   sq[1] = SquareLo8(s);
1874   sq[2] = SquareHi8(s);
1875   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1876   Sum3Horizontal(src, over_read_in_bytes, s3 + 2);
1877   StoreAligned64(sum3[2] + x, s3 + 2);
1878   Sum3WHorizontal(sq + 0, sq3[2]);
1879   StoreAligned64(square_sum3[2] + x, sq3[2]);
1880   LoadAligned32x2U16(sum3, x, s3);
1881   LoadAligned64x2U32(square_sum3, x, sq3);
1882   CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]);
1883 
1884   Sum3WHorizontal(sq + 1, sq3[2]);
1885   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
1886   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3 + 1);
1887   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
1888   CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]);
1889   CalculateIntermediate<9>(sum, index, ma, b + 1);
1890   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1891 }
1892 
BoxFilterPreProcessLo(const __m128i s[2],const uint16_t scales[2],uint16_t * const sum3[4],uint16_t * const sum5[5],uint32_t * const square_sum3[4],uint32_t * const square_sum5[5],__m128i sq[2][2],__m128i ma3[2],__m128i b3[2],__m128i * const ma5,__m128i * const b5)1893 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
1894     const __m128i s[2], const uint16_t scales[2], uint16_t* const sum3[4],
1895     uint16_t* const sum5[5], uint32_t* const square_sum3[4],
1896     uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i ma3[2],
1897     __m128i b3[2], __m128i* const ma5, __m128i* const b5) {
1898   __m128i s3[4], s5[5], sq3[4][2], sq5[5][2];
1899   sq[0][1] = SquareHi8(s[0]);
1900   sq[1][1] = SquareHi8(s[1]);
1901   SumHorizontalLo(s[0], &s3[2], &s5[3]);
1902   SumHorizontalLo(s[1], &s3[3], &s5[4]);
1903   StoreAligned16(sum3[2], s3[2]);
1904   StoreAligned16(sum3[3], s3[3]);
1905   StoreAligned16(sum5[3], s5[3]);
1906   StoreAligned16(sum5[4], s5[4]);
1907   SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1908   StoreAligned32U32(square_sum3[2], sq3[2]);
1909   StoreAligned32U32(square_sum5[3], sq5[3]);
1910   SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1911   StoreAligned32U32(square_sum3[3], sq3[3]);
1912   StoreAligned32U32(square_sum5[4], sq5[4]);
1913   LoadAligned16x2U16(sum3, 0, s3);
1914   LoadAligned32x2U32(square_sum3, 0, sq3);
1915   LoadAligned16x3U16(sum5, 0, s5);
1916   LoadAligned32x3U32(square_sum5, 0, sq5);
1917   // Note: in the SSE4_1 version, CalculateIntermediate() is called to replace
1918   // the slow LookupIntermediate() when calculating 16 intermediate data
1919   // points. However, for AVX2 the compiler generates even slower code, so we
1920   // keep using CalculateIntermediate3() here.
1921   CalculateIntermediate3(s3 + 0, sq3 + 0, scales[1], &ma3[0], &b3[0]);
1922   CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], &ma3[1], &b3[1]);
1923   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
1924 }
1925 
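// Computes the horizontal sums, squared sums and intermediate ma/b values
// needed by both the 5x5 (pass 1) and 3x3 (pass 2) filters for two source
// rows at a time.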
BoxFilterPreProcess(const uint8_t * const src0,const uint8_t * const src1,const ptrdiff_t over_read_in_bytes,const ptrdiff_t x,const uint16_t scales[2],uint16_t * const sum3[4],uint16_t * const sum5[5],uint32_t * const square_sum3[4],uint32_t * const square_sum5[5],const ptrdiff_t sum_width,__m256i sq[2][3],__m256i ma3[2][3],__m256i b3[2][5],__m256i ma5[3],__m256i b5[5])1926 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
1927     const uint8_t* const src0, const uint8_t* const src1,
1928     const ptrdiff_t over_read_in_bytes, const ptrdiff_t x,
1929     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
1930     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1931     const ptrdiff_t sum_width, __m256i sq[2][3], __m256i ma3[2][3],
1932     __m256i b3[2][5], __m256i ma5[3], __m256i b5[5]) {
1933   const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8);
1934   const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8);
1935   __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2], index_3[2][2],
1936       sum_5[2], index_5[2];
1937   sq[0][1] = SquareLo8(s0);
1938   sq[0][2] = SquareHi8(s0);
1939   sq[1][1] = SquareLo8(s1);
1940   sq[1][2] = SquareHi8(s1);
1941   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
1942   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
1943   SumHorizontal(src0, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
1944                 &s5[1][3]);
1945   SumHorizontal(src1, over_read_in_bytes, &s3[0][3], &s3[1][3], &s5[0][4],
1946                 &s5[1][4]);
1947   StoreAligned32(sum3[2] + x + 0, s3[0][2]);
1948   StoreAligned32(sum3[2] + x + 16, s3[1][2]);
1949   StoreAligned32(sum3[3] + x + 0, s3[0][3]);
1950   StoreAligned32(sum3[3] + x + 16, s3[1][3]);
1951   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
1952   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
1953   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
1954   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
1955   SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1956   SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1957   StoreAligned64(square_sum3[2] + x, sq3[2]);
1958   StoreAligned64(square_sum5[3] + x, sq5[3]);
1959   StoreAligned64(square_sum3[3] + x, sq3[3]);
1960   StoreAligned64(square_sum5[4] + x, sq5[4]);
1961   LoadAligned32x2U16(sum3, x, s3[0]);
1962   LoadAligned64x2U32(square_sum3, x, sq3);
1963   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0][0], &index_3[0][0]);
1964   CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum_3[1][0],
1965                         &index_3[1][0]);
1966   LoadAligned32x3U16(sum5, x, s5[0]);
1967   LoadAligned64x3U32(square_sum5, x, sq5);
1968   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
1969 
1970   SumHorizontal(sq[0] + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1971   SumHorizontal(sq[1] + 1, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1972   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
1973   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
1974   StoreAligned64(square_sum3[3] + x + 16, sq3[3]);
1975   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
1976   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
1977   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
1978   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]);
1979   CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum_3[1][1],
1980                         &index_3[1][1]);
1981   CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], b3[0] + 1);
1982   CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], b3[1] + 1);
1983   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1984   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1985   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
1986   CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1);
1987   b3[0][0] = _mm256_permute2x128_si256(b3[0][0], b3[0][2], 0x21);
1988   b3[1][0] = _mm256_permute2x128_si256(b3[1][0], b3[1][2], 0x21);
1989   b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21);
1990 }
1991 
BoxFilterPreProcessLastRowLo(const __m128i s,const uint16_t scales[2],const uint16_t * const sum3[4],const uint16_t * const sum5[5],const uint32_t * const square_sum3[4],const uint32_t * const square_sum5[5],__m128i sq[2],__m128i * const ma3,__m128i * const ma5,__m128i * const b3,__m128i * const b5)1992 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
1993     const __m128i s, const uint16_t scales[2], const uint16_t* const sum3[4],
1994     const uint16_t* const sum5[5], const uint32_t* const square_sum3[4],
1995     const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma3,
1996     __m128i* const ma5, __m128i* const b3, __m128i* const b5) {
1997   __m128i s3[3], s5[5], sq3[3][2], sq5[5][2];
1998   sq[1] = SquareHi8(s);
1999   SumHorizontalLo(s, &s3[2], &s5[3]);
2000   SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2001   LoadAligned16x3U16(sum5, 0, s5);
2002   s5[4] = s5[3];
2003   LoadAligned32x3U32(square_sum5, 0, sq5);
2004   sq5[4][0] = sq5[3][0];
2005   sq5[4][1] = sq5[3][1];
2006   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
2007   LoadAligned16x2U16(sum3, 0, s3);
2008   LoadAligned32x2U32(square_sum3, 0, sq3);
2009   CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
2010 }
2011 
BoxFilterPreProcessLastRow(const uint8_t * const src,const ptrdiff_t over_read_in_bytes,const ptrdiff_t sum_width,const ptrdiff_t x,const uint16_t scales[2],const uint16_t * const sum3[4],const uint16_t * const sum5[5],const uint32_t * const square_sum3[4],const uint32_t * const square_sum5[5],__m256i sq[6],__m256i ma3[2],__m256i ma5[2],__m256i b3[5],__m256i b5[5])2012 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
2013     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
2014     const ptrdiff_t sum_width, const ptrdiff_t x, const uint16_t scales[2],
2015     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
2016     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
2017     __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5],
2018     __m256i b5[5]) {
2019   const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
2020   __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2],
2021       sum_5[2], index_5[2];
2022   sq[1] = SquareLo8(s0);
2023   sq[2] = SquareHi8(s0);
2024   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
2025   SumHorizontal(src, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
2026                 &s5[1][3]);
2027   SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2028   LoadAligned32x2U16(sum3, x, s3[0]);
2029   LoadAligned64x2U32(square_sum3, x, sq3);
2030   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0], &index_3[0]);
2031   LoadAligned32x3U16(sum5, x, s5[0]);
2032   s5[0][4] = s5[0][3];
2033   LoadAligned64x3U32(square_sum5, x, sq5);
2034   sq5[4][0] = sq5[3][0];
2035   sq5[4][1] = sq5[3][1];
2036   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
2037 
2038   SumHorizontal(sq + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2039   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
2040   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
2041   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]);
2042   CalculateIntermediate<9>(sum_3, index_3, ma3, b3 + 1);
2043   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
2044   s5[1][4] = s5[1][3];
2045   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
2046   sq5[4][0] = sq5[3][0];
2047   sq5[4][1] = sq5[3][1];
2048   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
2049   CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1);
2050   b3[0] = _mm256_permute2x128_si256(b3[0], b3[2], 0x21);
2051   b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21);
2052 }
2053 
BoxSumFilterPreProcess5(const uint8_t * const src0,const uint8_t * const src1,const int width,const uint32_t scale,uint16_t * const sum5[5],uint32_t * const square_sum5[5],const ptrdiff_t sum_width,uint16_t * ma565,uint32_t * b565)2054 inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
2055                                     const uint8_t* const src1, const int width,
2056                                     const uint32_t scale,
2057                                     uint16_t* const sum5[5],
2058                                     uint32_t* const square_sum5[5],
2059                                     const ptrdiff_t sum_width, uint16_t* ma565,
2060                                     uint32_t* b565) {
2061   __m128i ma0, b0, s[2][3], sq_128[2][2];
2062   __m256i mas[3], sq[2][3], bs[3];
2063   s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2064   s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2065   sq_128[0][0] = SquareLo8(s[0][0]);
2066   sq_128[1][0] = SquareLo8(s[1][0]);
2067   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0);
2068   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2069   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2070   mas[0] = SetrM128i(ma0, ma0);
2071   bs[0] = SetrM128i(b0, b0);
2072 
2073   int x = 0;
2074   do {
2075     __m256i ma5[3], ma[2], b[4];
2076     BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8,
2077                          x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2078                          x + 8, scale, sum5, square_sum5, sq, mas, bs);
2079     Prepare3_8(mas, ma5);
2080     ma[0] = Sum565Lo(ma5);
2081     ma[1] = Sum565Hi(ma5);
2082     StoreAligned64(ma565, ma);
2083     Sum565W(bs + 0, b + 0);
2084     Sum565W(bs + 1, b + 2);
2085     StoreAligned64(b565, b + 0);
2086     StoreAligned64(b565 + 16, b + 2);
2087     sq[0][0] = sq[0][2];
2088     sq[1][0] = sq[1][2];
2089     mas[0] = mas[2];
2090     bs[0] = bs[2];
2091     ma565 += 32;
2092     b565 += 32;
2093     x += 32;
2094   } while (x < width);
2095 }
2096 
2097 template <bool calculate444>
BoxSumFilterPreProcess3(const uint8_t * const src,const int width,const uint32_t scale,uint16_t * const sum3[3],uint32_t * const square_sum3[3],const ptrdiff_t sum_width,uint16_t * ma343,uint16_t * ma444,uint32_t * b343,uint32_t * b444)2098 LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
2099     const uint8_t* const src, const int width, const uint32_t scale,
2100     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2101     const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343,
2102     uint32_t* b444) {
2103   const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width);
2104   __m128i ma0, sq_128[2], b0;
2105   __m256i mas[3], sq[3], bs[3];
2106   sq_128[0] = SquareLo8(s);
2107   BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, &b0);
2108   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2109   mas[0] = SetrM128i(ma0, ma0);
2110   bs[0] = SetrM128i(b0, b0);
2111 
2112   int x = 0;
2113   do {
2114     __m256i ma3[3];
2115     BoxFilterPreProcess3(src + x + 8, x + 8 + kOverreadInBytesPass2_256 - width,
2116                          x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
2117                          bs);
2118     Prepare3_8(mas, ma3);
2119     if (calculate444) {  // NOLINT(readability-simplify-boolean-expr)
2120       Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444);
2121       Store343_444Hi(ma3, bs + 1, 16, ma343, ma444, b343, b444);
2122       ma444 += 32;
2123       b444 += 32;
2124     } else {
2125       __m256i ma[2], b[4];
2126       ma[0] = Sum343Lo(ma3);
2127       ma[1] = Sum343Hi(ma3);
2128       StoreAligned64(ma343, ma);
2129       Sum343W(bs + 0, b + 0);
2130       Sum343W(bs + 1, b + 2);
2131       StoreAligned64(b343 + 0, b + 0);
2132       StoreAligned64(b343 + 16, b + 2);
2133     }
2134     sq[0] = sq[2];
2135     mas[0] = mas[2];
2136     bs[0] = bs[2];
2137     ma343 += 32;
2138     b343 += 32;
2139     x += 32;
2140   } while (x < width);
2141 }
2142 
BoxSumFilterPreProcess(const uint8_t * const src0,const uint8_t * const src1,const int width,const uint16_t scales[2],uint16_t * const sum3[4],uint16_t * const sum5[5],uint32_t * const square_sum3[4],uint32_t * const square_sum5[5],const ptrdiff_t sum_width,uint16_t * const ma343[4],uint16_t * const ma444,uint16_t * ma565,uint32_t * const b343[4],uint32_t * const b444,uint32_t * b565)2143 inline void BoxSumFilterPreProcess(
2144     const uint8_t* const src0, const uint8_t* const src1, const int width,
2145     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
2146     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2147     const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444,
2148     uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444,
2149     uint32_t* b565) {
2150   __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0;
2151   __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5];
2152   s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2153   s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2154   sq_128[0][0] = SquareLo8(s[0]);
2155   sq_128[1][0] = SquareLo8(s[1]);
2156   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2157                         ma3_128, b3_128, &ma5_0, &b5_0);
2158   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2159   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2160   ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]);
2161   ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]);
2162   ma5[0] = SetrM128i(ma5_0, ma5_0);
2163   b3[0][0] = SetrM128i(b3_128[0], b3_128[0]);
2164   b3[1][0] = SetrM128i(b3_128[1], b3_128[1]);
2165   b5[0] = SetrM128i(b5_0, b5_0);
2166 
2167   int x = 0;
2168   do {
2169     __m256i ma[2], b[4], ma3x[3], ma5x[3];
2170     BoxFilterPreProcess(src0 + x + 8, src1 + x + 8,
2171                         x + 8 + kOverreadInBytesPass1_256 - width, x + 8,
2172                         scales, sum3, sum5, square_sum3, square_sum5, sum_width,
2173                         sq, ma3, b3, ma5, b5);
2174     Prepare3_8(ma3[0], ma3x);
2175     ma[0] = Sum343Lo(ma3x);
2176     ma[1] = Sum343Hi(ma3x);
2177     StoreAligned64(ma343[0] + x, ma);
2178     Sum343W(b3[0], b);
2179     StoreAligned64(b343[0] + x, b);
2180     Sum565W(b5, b);
2181     StoreAligned64(b565, b);
2182     Prepare3_8(ma3[1], ma3x);
2183     Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
2184     Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444, b343[1], b444);
2185     Prepare3_8(ma5, ma5x);
2186     ma[0] = Sum565Lo(ma5x);
2187     ma[1] = Sum565Hi(ma5x);
2188     StoreAligned64(ma565, ma);
2189     Sum343W(b3[0] + 1, b);
2190     StoreAligned64(b343[0] + x + 16, b);
2191     Sum565W(b5 + 1, b);
2192     StoreAligned64(b565 + 16, b);
2193     sq[0][0] = sq[0][2];
2194     sq[1][0] = sq[1][2];
2195     ma3[0][0] = ma3[0][2];
2196     ma3[1][0] = ma3[1][2];
2197     ma5[0] = ma5[2];
2198     b3[0][0] = b3[0][2];
2199     b3[1][0] = b3[1][2];
2200     b5[0] = b5[2];
2201     ma565 += 32;
2202     b565 += 32;
2203     x += 32;
2204   } while (x < width);
2205 }
2206 
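// Core of the self-guided filter: returns (b - ma * src), rounded and right
// shifted by kSgrProjSgrBits + shift - kSgrProjRestoreBits.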
2207 template <int shift>
2208 inline __m256i FilterOutput(const __m256i ma_x_src, const __m256i b) {
2209   // ma: 255 * 32 = 8160 (13 bits)
2210   // b: 65088 * 32 = 2082816 (21 bits)
2211   // v: b - ma * 255 (22 bits)
2212   const __m256i v = _mm256_sub_epi32(b, ma_x_src);
2213   // kSgrProjSgrBits = 8
2214   // kSgrProjRestoreBits = 4
2215   // shift = 4 or 5
2216   // v >> 8 or 9 (13 bits)
2217   return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
2218 }
2219 
2220 template <int shift>
2221 inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma,
2222                                        const __m256i b[2]) {
2223   const __m256i ma_x_src_lo = VmullLo16(ma, src);
2224   const __m256i ma_x_src_hi = VmullHi16(ma, src);
2225   const __m256i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]);
2226   const __m256i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]);
2227   return _mm256_packs_epi32(dst_lo, dst_hi);  // 13 bits
2228 }
2229 
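// Pass 1 output: the 565 intermediates of the two buffered rows are summed
// (shift 5) before the filtered value is computed.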
2230 inline __m256i CalculateFilteredOutputPass1(const __m256i src,
2231                                             const __m256i ma[2],
2232                                             const __m256i b[2][2]) {
2233   const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]);
2234   __m256i b_sum[2];
2235   b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]);
2236   b_sum[1] = _mm256_add_epi32(b[0][1], b[1][1]);
2237   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2238 }
2239 
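// Pass 2 output: the 343/444 intermediates of three buffered rows are summed
// (shift 5) before the filtered value is computed.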
2240 inline __m256i CalculateFilteredOutputPass2(const __m256i src,
2241                                             const __m256i ma[3],
2242                                             const __m256i b[3][2]) {
2243   const __m256i ma_sum = Sum3_16(ma);
2244   __m256i b_sum[2];
2245   Sum3_32(b, b_sum);
2246   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2247 }
2248 
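// Rounds the weighted filter output back to pixel precision and adds it to
// the source pixels.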
2249 inline __m256i SelfGuidedFinal(const __m256i src, const __m256i v[2]) {
2250   const __m256i v_lo =
2251       VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2252   const __m256i v_hi =
2253       VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2254   const __m256i vv = _mm256_packs_epi32(v_lo, v_hi);
2255   return _mm256_add_epi16(src, vv);
2256 }
2257 
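// Combines both passes: filter[0] is weighted by w0 and filter[1] by w2 with
// a single multiply-add per half.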
2258 inline __m256i SelfGuidedDoubleMultiplier(const __m256i src,
2259                                           const __m256i filter[2], const int w0,
2260                                           const int w2) {
2261   __m256i v[2];
2262   const __m256i w0_w2 =
2263       _mm256_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0));
2264   const __m256i f_lo = _mm256_unpacklo_epi16(filter[0], filter[1]);
2265   const __m256i f_hi = _mm256_unpackhi_epi16(filter[0], filter[1]);
2266   v[0] = _mm256_madd_epi16(w0_w2, f_lo);
2267   v[1] = _mm256_madd_epi16(w0_w2, f_hi);
2268   return SelfGuidedFinal(src, v);
2269 }
2270 
2271 inline __m256i SelfGuidedSingleMultiplier(const __m256i src,
2272                                           const __m256i filter, const int w0) {
2273   // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
2274   __m256i v[2];
2275   v[0] = VmullNLo8(filter, w0);
2276   v[1] = VmullNHi8(filter, w0);
2277   return SelfGuidedFinal(src, v);
2278 }
2279 
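// 5x5 box filter pass producing two output rows per loop iteration. |src0|
// and |src1| feed the running box sums; |ma565| and |b565| hold the
// intermediates shared between row pairs.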
2280 LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
2281     const uint8_t* const src, const uint8_t* const src0,
2282     const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
2283     uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width,
2284     const uint32_t scale, const int16_t w0, uint16_t* const ma565[2],
2285     uint32_t* const b565[2], uint8_t* const dst) {
2286   __m128i ma0, b0, s[2][3], sq_128[2][2];
2287   __m256i mas[3], sq[2][3], bs[3];
2288   s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2289   s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2290   sq_128[0][0] = SquareLo8(s[0][0]);
2291   sq_128[1][0] = SquareLo8(s[1][0]);
2292   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0);
2293   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2294   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2295   mas[0] = SetrM128i(ma0, ma0);
2296   bs[0] = SetrM128i(b0, b0);
2297 
2298   int x = 0;
2299   do {
2300     __m256i ma[3], ma5[3], b[2][2][2];
2301     BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8,
2302                          x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2303                          x + 8, scale, sum5, square_sum5, sq, mas, bs);
2304     Prepare3_8(mas, ma5);
2305     ma[1] = Sum565Lo(ma5);
2306     ma[2] = Sum565Hi(ma5);
2307     StoreAligned64(ma565[1] + x, ma + 1);
2308     Sum565W(bs + 0, b[0][1]);
2309     Sum565W(bs + 1, b[1][1]);
2310     StoreAligned64(b565[1] + x + 0, b[0][1]);
2311     StoreAligned64(b565[1] + x + 16, b[1][1]);
2312     const __m256i sr0 = LoadUnaligned32(src + x);
2313     const __m256i sr1 = LoadUnaligned32(src + stride + x);
2314     const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256());
2315     const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256());
2316     ma[0] = LoadAligned32(ma565[0] + x);
2317     LoadAligned64(b565[0] + x, b[0][0]);
2318     const __m256i p00 = CalculateFilteredOutputPass1(sr0_lo, ma, b[0]);
2319     const __m256i p01 = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[0][1]);
2320     const __m256i d00 = SelfGuidedSingleMultiplier(sr0_lo, p00, w0);
2321     const __m256i d10 = SelfGuidedSingleMultiplier(sr1_lo, p01, w0);
2322     const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256());
2323     const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256());
2324     ma[1] = LoadAligned32(ma565[0] + x + 16);
2325     LoadAligned64(b565[0] + x + 16, b[1][0]);
2326     const __m256i p10 = CalculateFilteredOutputPass1(sr0_hi, ma + 1, b[1]);
2327     const __m256i p11 = CalculateFilteredOutput<4>(sr1_hi, ma[2], b[1][1]);
2328     const __m256i d01 = SelfGuidedSingleMultiplier(sr0_hi, p10, w0);
2329     const __m256i d11 = SelfGuidedSingleMultiplier(sr1_hi, p11, w0);
2330     StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01));
2331     StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11));
2332     sq[0][0] = sq[0][2];
2333     sq[1][0] = sq[1][2];
2334     mas[0] = mas[2];
2335     bs[0] = bs[2];
2336     x += 32;
2337   } while (x < width);
2338 }
2339 
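// Same as BoxFilterPass1(), but handles the final source row when |height| is
// odd, producing a single output row.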
2340 inline void BoxFilterPass1LastRow(
2341     const uint8_t* const src, const uint8_t* const src0, const int width,
2342     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2343     uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565,
2344     uint32_t* b565, uint8_t* const dst) {
2345   const __m128i s0 =
2346       LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2347   __m128i ma0, b0, sq_128[2];
2348   __m256i mas[3], sq[3], bs[3];
2349   sq_128[0] = SquareLo8(s0);
2350   BoxFilterPreProcess5LastRowLo(s0, scale, sum5, square_sum5, sq_128, &ma0,
2351                                 &b0);
2352   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2353   mas[0] = SetrM128i(ma0, ma0);
2354   bs[0] = SetrM128i(b0, b0);
2355 
2356   int x = 0;
2357   do {
2358     __m256i ma[3], ma5[3], b[2][2];
2359     BoxFilterPreProcess5LastRow(
2360         src0 + x + 8, x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2361         x + 8, scale, sum5, square_sum5, sq, mas, bs);
2362     Prepare3_8(mas, ma5);
2363     ma[1] = Sum565Lo(ma5);
2364     ma[2] = Sum565Hi(ma5);
2365     Sum565W(bs + 0, b[1]);
2366     const __m256i sr = LoadUnaligned32(src + x);
2367     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2368     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2369     ma[0] = LoadAligned32(ma565);
2370     LoadAligned64(b565 + 0, b[0]);
2371     const __m256i p0 = CalculateFilteredOutputPass1(sr_lo, ma, b);
2372     ma[1] = LoadAligned32(ma565 + 16);
2373     LoadAligned64(b565 + 16, b[0]);
2374     Sum565W(bs + 1, b[1]);
2375     const __m256i p1 = CalculateFilteredOutputPass1(sr_hi, ma + 1, b);
2376     const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
2377     const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
2378     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2379     sq[0] = sq[2];
2380     mas[0] = mas[2];
2381     bs[0] = bs[2];
2382     ma565 += 32;
2383     b565 += 32;
2384     x += 32;
2385   } while (x < width);
2386 }
2387 
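// 3x3 box filter pass producing one output row per loop iteration.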
2388 LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
2389     const uint8_t* const src, const uint8_t* const src0, const int width,
2390     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2391     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2392     uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3],
2393     uint32_t* const b444[2], uint8_t* const dst) {
2394   const __m128i s0 =
2395       LoadUnaligned16Msan(src0, kOverreadInBytesPass2_128 - width);
2396   __m128i ma0, b0, sq_128[2];
2397   __m256i mas[3], sq[3], bs[3];
2398   sq_128[0] = SquareLo8(s0);
2399   BoxFilterPreProcess3Lo(s0, scale, sum3, square_sum3, sq_128, &ma0, &b0);
2400   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2401   mas[0] = SetrM128i(ma0, ma0);
2402   bs[0] = SetrM128i(b0, b0);
2403 
2404   int x = 0;
2405   do {
2406     __m256i ma[4], b[4][2], ma3[3];
2407     BoxFilterPreProcess3(src0 + x + 8,
2408                          x + 8 + kOverreadInBytesPass2_256 - width, x + 8,
2409                          sum_width, scale, sum3, square_sum3, sq, mas, bs);
2410     Prepare3_8(mas, ma3);
2411     Store343_444Lo(ma3, bs + 0, x + 0, &ma[2], b[2], ma343[2], ma444[1],
2412                    b343[2], b444[1]);
2413     Store343_444Hi(ma3, bs + 1, x + 16, &ma[3], b[3], ma343[2], ma444[1],
2414                    b343[2], b444[1]);
2415     const __m256i sr = LoadUnaligned32(src + x);
2416     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2417     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2418     ma[0] = LoadAligned32(ma343[0] + x);
2419     ma[1] = LoadAligned32(ma444[0] + x);
2420     LoadAligned64(b343[0] + x, b[0]);
2421     LoadAligned64(b444[0] + x, b[1]);
2422     const __m256i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b);
2423     ma[1] = LoadAligned32(ma343[0] + x + 16);
2424     ma[2] = LoadAligned32(ma444[0] + x + 16);
2425     LoadAligned64(b343[0] + x + 16, b[1]);
2426     LoadAligned64(b444[0] + x + 16, b[2]);
2427     const __m256i p1 = CalculateFilteredOutputPass2(sr_hi, ma + 1, b + 1);
2428     const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
2429     const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
2430     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2431     sq[0] = sq[2];
2432     mas[0] = mas[2];
2433     bs[0] = bs[2];
2434     x += 32;
2435   } while (x < width);
2436 }
2437 
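// Combined filter: runs the 5x5 and 3x3 passes together and blends their
// outputs with weights w0 and w2, producing two output rows per loop
// iteration.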
2438 LIBGAV1_ALWAYS_INLINE void BoxFilter(
2439     const uint8_t* const src, const uint8_t* const src0,
2440     const uint8_t* const src1, const ptrdiff_t stride, const int width,
2441     const uint16_t scales[2], const int16_t w0, const int16_t w2,
2442     uint16_t* const sum3[4], uint16_t* const sum5[5],
2443     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2444     const ptrdiff_t sum_width, uint16_t* const ma343[4],
2445     uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4],
2446     uint32_t* const b444[3], uint32_t* const b565[2], uint8_t* const dst) {
2447   __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0;
2448   __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5];
2449   s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2450   s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2451   sq_128[0][0] = SquareLo8(s[0]);
2452   sq_128[1][0] = SquareLo8(s[1]);
2453   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2454                         ma3_128, b3_128, &ma5_0, &b5_0);
2455   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2456   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2457   ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]);
2458   ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]);
2459   ma5[0] = SetrM128i(ma5_0, ma5_0);
2460   b3[0][0] = SetrM128i(b3_128[0], b3_128[0]);
2461   b3[1][0] = SetrM128i(b3_128[1], b3_128[1]);
2462   b5[0] = SetrM128i(b5_0, b5_0);
2463 
2464   int x = 0;
2465   do {
2466     __m256i ma[3][3], mat[3][3], b[3][3][2], p[2][2], ma3x[2][3], ma5x[3];
2467     BoxFilterPreProcess(src0 + x + 8, src1 + x + 8,
2468                         x + 8 + kOverreadInBytesPass1_256 - width, x + 8,
2469                         scales, sum3, sum5, square_sum3, square_sum5, sum_width,
2470                         sq, ma3, b3, ma5, b5);
2471     Prepare3_8(ma3[0], ma3x[0]);
2472     Prepare3_8(ma3[1], ma3x[1]);
2473     Prepare3_8(ma5, ma5x);
2474     Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1],
2475                    ma343[2], ma444[1], b343[2], b444[1]);
2476     Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2],
2477                    b343[3], b444[2]);
2478     ma[0][1] = Sum565Lo(ma5x);
2479     ma[0][2] = Sum565Hi(ma5x);
2480     mat[0][1] = ma[0][2];
2481     StoreAligned64(ma565[1] + x, ma[0] + 1);
2482     Sum565W(b5, b[0][1]);
2483     StoreAligned64(b565[1] + x, b[0][1]);
2484     const __m256i sr0 = LoadUnaligned32(src + x);
2485     const __m256i sr1 = LoadUnaligned32(src + stride + x);
2486     const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256());
2487     const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256());
2488     ma[0][0] = LoadAligned32(ma565[0] + x);
2489     LoadAligned64(b565[0] + x, b[0][0]);
2490     p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]);
2491     p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]);
2492     ma[1][0] = LoadAligned32(ma343[0] + x);
2493     ma[1][1] = LoadAligned32(ma444[0] + x);
2494     LoadAligned64(b343[0] + x, b[1][0]);
2495     LoadAligned64(b444[0] + x, b[1][1]);
2496     p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]);
2497     const __m256i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2);
2498     ma[2][0] = LoadAligned32(ma343[1] + x);
2499     LoadAligned64(b343[1] + x, b[2][0]);
2500     p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]);
2501     const __m256i d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2);
2502 
2503     Sum565W(b5 + 1, b[0][1]);
2504     StoreAligned64(b565[1] + x + 16, b[0][1]);
2505     Store343_444Hi(ma3x[0], b3[0] + 1, x + 16, &mat[1][2], &mat[2][1], b[1][2],
2506                    b[2][1], ma343[2], ma444[1], b343[2], b444[1]);
2507     Store343_444Hi(ma3x[1], b3[1] + 1, x + 16, &mat[2][2], b[2][2], ma343[3],
2508                    ma444[2], b343[3], b444[2]);
2509     const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256());
2510     const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256());
2511     mat[0][0] = LoadAligned32(ma565[0] + x + 16);
2512     LoadAligned64(b565[0] + x + 16, b[0][0]);
2513     p[0][0] = CalculateFilteredOutputPass1(sr0_hi, mat[0], b[0]);
2514     p[1][0] = CalculateFilteredOutput<4>(sr1_hi, mat[0][1], b[0][1]);
2515     mat[1][0] = LoadAligned32(ma343[0] + x + 16);
2516     mat[1][1] = LoadAligned32(ma444[0] + x + 16);
2517     LoadAligned64(b343[0] + x + 16, b[1][0]);
2518     LoadAligned64(b444[0] + x + 16, b[1][1]);
2519     p[0][1] = CalculateFilteredOutputPass2(sr0_hi, mat[1], b[1]);
2520     const __m256i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2);
2521     mat[2][0] = LoadAligned32(ma343[1] + x + 16);
2522     LoadAligned64(b343[1] + x + 16, b[2][0]);
2523     p[1][1] = CalculateFilteredOutputPass2(sr1_hi, mat[2], b[2]);
2524     const __m256i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2);
2525     StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01));
2526     StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11));
2527     sq[0][0] = sq[0][2];
2528     sq[1][0] = sq[1][2];
2529     ma3[0][0] = ma3[0][2];
2530     ma3[1][0] = ma3[1][2];
2531     ma5[0] = ma5[2];
2532     b3[0][0] = b3[0][2];
2533     b3[1][0] = b3[1][2];
2534     b5[0] = b5[2];
2535     x += 32;
2536   } while (x < width);
2537 }
2538 
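// Combined-filter variant for the final source row when |height| is odd.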
2539 inline void BoxFilterLastRow(
2540     const uint8_t* const src, const uint8_t* const src0, const int width,
2541     const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0,
2542     const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5],
2543     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2544     uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
2545     uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
2546     uint8_t* const dst) {
2547   const __m128i s0 =
2548       LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2549   __m128i ma3_0, ma5_0, b3_0, b5_0, sq_128[2];
2550   __m256i ma3[3], ma5[3], sq[3], b3[3], b5[3];
2551   sq_128[0] = SquareLo8(s0);
2552   BoxFilterPreProcessLastRowLo(s0, scales, sum3, sum5, square_sum3, square_sum5,
2553                                sq_128, &ma3_0, &ma5_0, &b3_0, &b5_0);
2554   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2555   ma3[0] = SetrM128i(ma3_0, ma3_0);
2556   ma5[0] = SetrM128i(ma5_0, ma5_0);
2557   b3[0] = SetrM128i(b3_0, b3_0);
2558   b5[0] = SetrM128i(b5_0, b5_0);
2559 
2560   int x = 0;
2561   do {
2562     __m256i ma[3], mat[3], b[3][2], p[2], ma3x[3], ma5x[3];
2563     BoxFilterPreProcessLastRow(src0 + x + 8,
2564                                x + 8 + kOverreadInBytesPass1_256 - width,
2565                                sum_width, x + 8, scales, sum3, sum5,
2566                                square_sum3, square_sum5, sq, ma3, ma5, b3, b5);
2567     Prepare3_8(ma3, ma3x);
2568     Prepare3_8(ma5, ma5x);
2569     ma[1] = Sum565Lo(ma5x);
2570     Sum565W(b5, b[1]);
2571     ma[2] = Sum343Lo(ma3x);
2572     Sum343W(b3, b[2]);
2573     const __m256i sr = LoadUnaligned32(src + x);
2574     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2575     ma[0] = LoadAligned32(ma565 + x);
2576     LoadAligned64(b565 + x, b[0]);
2577     p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b);
2578     ma[0] = LoadAligned32(ma343 + x);
2579     ma[1] = LoadAligned32(ma444 + x);
2580     LoadAligned64(b343 + x, b[0]);
2581     LoadAligned64(b444 + x, b[1]);
2582     p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b);
2583     const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2);
2584 
2585     mat[1] = Sum565Hi(ma5x);
2586     Sum565W(b5 + 1, b[1]);
2587     mat[2] = Sum343Hi(ma3x);
2588     Sum343W(b3 + 1, b[2]);
2589     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2590     mat[0] = LoadAligned32(ma565 + x + 16);
2591     LoadAligned64(b565 + x + 16, b[0]);
2592     p[0] = CalculateFilteredOutputPass1(sr_hi, mat, b);
2593     mat[0] = LoadAligned32(ma343 + x + 16);
2594     mat[1] = LoadAligned32(ma444 + x + 16);
2595     LoadAligned64(b343 + x + 16, b[0]);
2596     LoadAligned64(b444 + x + 16, b[1]);
2597     p[1] = CalculateFilteredOutputPass2(sr_hi, mat, b);
2598     const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2);
2599     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2600     sq[0] = sq[2];
2601     ma3[0] = ma3[2];
2602     ma5[0] = ma5[2];
2603     b3[0] = b3[2];
2604     b5[0] = b5[2];
2605     x += 32;
2606   } while (x < width);
2607 }
2608 
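// Drives the combined filter over a restoration unit: sets up the rotating
// row-sum buffers, pre-processes the top border, filters the body two rows at
// a time, and finishes with the bottom border rows.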
2609 LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
2610     const RestorationUnitInfo& restoration_info, const uint8_t* src,
2611     const ptrdiff_t stride, const uint8_t* const top_border,
2612     const ptrdiff_t top_border_stride, const uint8_t* bottom_border,
2613     const ptrdiff_t bottom_border_stride, const int width, const int height,
2614     SgrBuffer* const sgr_buffer, uint8_t* dst) {
2615   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2616   const auto sum_width = temp_stride + 8;
2617   const auto sum_stride = temp_stride + 32;
2618   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2619   const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
2620   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2621   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2622   const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
2623   uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
2624   uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
2625   sum3[0] = sgr_buffer->sum3 + kSumOffset;
2626   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
2627   ma343[0] = sgr_buffer->ma343;
2628   b343[0] = sgr_buffer->b343;
2629   for (int i = 1; i <= 3; ++i) {
2630     sum3[i] = sum3[i - 1] + sum_stride;
2631     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2632     ma343[i] = ma343[i - 1] + temp_stride;
2633     b343[i] = b343[i - 1] + temp_stride;
2634   }
2635   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2636   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2637   for (int i = 1; i <= 4; ++i) {
2638     sum5[i] = sum5[i - 1] + sum_stride;
2639     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2640   }
2641   ma444[0] = sgr_buffer->ma444;
2642   b444[0] = sgr_buffer->b444;
2643   for (int i = 1; i <= 2; ++i) {
2644     ma444[i] = ma444[i - 1] + temp_stride;
2645     b444[i] = b444[i - 1] + temp_stride;
2646   }
2647   ma565[0] = sgr_buffer->ma565;
2648   ma565[1] = ma565[0] + temp_stride;
2649   b565[0] = sgr_buffer->b565;
2650   b565[1] = b565[0] + temp_stride;
2651   assert(scales[0] != 0);
2652   assert(scales[1] != 0);
2653   BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0],
2654          sum5[1], square_sum3[0], square_sum5[1]);
2655   sum5[0] = sum5[1];
2656   square_sum5[0] = square_sum5[1];
2657   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2658   BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
2659                          square_sum5, sum_width, ma343, ma444[0], ma565[0],
2660                          b343, b444[0], b565[0]);
2661   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2662   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2663 
2664   for (int y = (height >> 1) - 1; y > 0; --y) {
2665     Circulate4PointersBy2<uint16_t>(sum3);
2666     Circulate4PointersBy2<uint32_t>(square_sum3);
2667     Circulate5PointersBy2<uint16_t>(sum5);
2668     Circulate5PointersBy2<uint32_t>(square_sum5);
2669     BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
2670               scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width,
2671               ma343, ma444, ma565, b343, b444, b565, dst);
2672     src += 2 * stride;
2673     dst += 2 * stride;
2674     Circulate4PointersBy2<uint16_t>(ma343);
2675     Circulate4PointersBy2<uint32_t>(b343);
2676     std::swap(ma444[0], ma444[2]);
2677     std::swap(b444[0], b444[2]);
2678     std::swap(ma565[0], ma565[1]);
2679     std::swap(b565[0], b565[1]);
2680   }
2681 
2682   Circulate4PointersBy2<uint16_t>(sum3);
2683   Circulate4PointersBy2<uint32_t>(square_sum3);
2684   Circulate5PointersBy2<uint16_t>(sum5);
2685   Circulate5PointersBy2<uint32_t>(square_sum5);
2686   if ((height & 1) == 0 || height > 1) {
2687     const uint8_t* sr[2];
2688     if ((height & 1) == 0) {
2689       sr[0] = bottom_border;
2690       sr[1] = bottom_border + bottom_border_stride;
2691     } else {
2692       sr[0] = src + 2 * stride;
2693       sr[1] = bottom_border;
2694     }
2695     BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
2696               square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343,
2697               b444, b565, dst);
2698   }
2699   if ((height & 1) != 0) {
2700     if (height > 1) {
2701       src += 2 * stride;
2702       dst += 2 * stride;
2703       Circulate4PointersBy2<uint16_t>(sum3);
2704       Circulate4PointersBy2<uint32_t>(square_sum3);
2705       Circulate5PointersBy2<uint16_t>(sum5);
2706       Circulate5PointersBy2<uint32_t>(square_sum5);
2707       Circulate4PointersBy2<uint16_t>(ma343);
2708       Circulate4PointersBy2<uint32_t>(b343);
2709       std::swap(ma444[0], ma444[2]);
2710       std::swap(b444[0], b444[2]);
2711       std::swap(ma565[0], ma565[1]);
2712       std::swap(b565[0], b565[1]);
2713     }
2714     BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
2715                      sum_width, scales, w0, w2, sum3, sum5, square_sum3,
2716                      square_sum5, ma343[0], ma444[0], ma565[0], b343[0],
2717                      b444[0], b565[0], dst);
2718   }
2719 }
2720 
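// Pass-1-only (5x5) driver, used when the pass 2 radius is 0.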
2721 inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
2722                                   const uint8_t* src, const ptrdiff_t stride,
2723                                   const uint8_t* const top_border,
2724                                   const ptrdiff_t top_border_stride,
2725                                   const uint8_t* bottom_border,
2726                                   const ptrdiff_t bottom_border_stride,
2727                                   const int width, const int height,
2728                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2729   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2730   const auto sum_width = temp_stride + 8;
2731   const auto sum_stride = temp_stride + 32;
2732   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2733   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
2734   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2735   uint16_t *sum5[5], *ma565[2];
2736   uint32_t *square_sum5[5], *b565[2];
2737   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2738   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2739   for (int i = 1; i <= 4; ++i) {
2740     sum5[i] = sum5[i - 1] + sum_stride;
2741     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2742   }
2743   ma565[0] = sgr_buffer->ma565;
2744   ma565[1] = ma565[0] + temp_stride;
2745   b565[0] = sgr_buffer->b565;
2746   b565[1] = b565[0] + temp_stride;
2747   assert(scale != 0);
2748   BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride,
2749             sum5[1], square_sum5[1]);
2750   sum5[0] = sum5[1];
2751   square_sum5[0] = square_sum5[1];
2752   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2753   BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width,
2754                           ma565[0], b565[0]);
2755   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2756   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2757 
2758   for (int y = (height >> 1) - 1; y > 0; --y) {
2759     Circulate5PointersBy2<uint16_t>(sum5);
2760     Circulate5PointersBy2<uint32_t>(square_sum5);
2761     BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
2762                    square_sum5, width, sum_width, scale, w0, ma565, b565, dst);
2763     src += 2 * stride;
2764     dst += 2 * stride;
2765     std::swap(ma565[0], ma565[1]);
2766     std::swap(b565[0], b565[1]);
2767   }
2768 
2769   Circulate5PointersBy2<uint16_t>(sum5);
2770   Circulate5PointersBy2<uint32_t>(square_sum5);
2771   if ((height & 1) == 0 || height > 1) {
2772     const uint8_t* sr[2];
2773     if ((height & 1) == 0) {
2774       sr[0] = bottom_border;
2775       sr[1] = bottom_border + bottom_border_stride;
2776     } else {
2777       sr[0] = src + 2 * stride;
2778       sr[1] = bottom_border;
2779     }
2780     BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
2781                    sum_width, scale, w0, ma565, b565, dst);
2782   }
2783   if ((height & 1) != 0) {
2784     src += 3;
2785     if (height > 1) {
2786       src += 2 * stride;
2787       dst += 2 * stride;
2788       std::swap(ma565[0], ma565[1]);
2789       std::swap(b565[0], b565[1]);
2790       Circulate5PointersBy2<uint16_t>(sum5);
2791       Circulate5PointersBy2<uint32_t>(square_sum5);
2792     }
2793     BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width,
2794                           sum_width, scale, w0, sum5, square_sum5, ma565[0],
2795                           b565[0], dst);
2796   }
2797 }
2798 
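// Pass-2-only (3x3) driver, used when the pass 1 radius is 0.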
2799 inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
2800                                   const uint8_t* src, const ptrdiff_t stride,
2801                                   const uint8_t* const top_border,
2802                                   const ptrdiff_t top_border_stride,
2803                                   const uint8_t* bottom_border,
2804                                   const ptrdiff_t bottom_border_stride,
2805                                   const int width, const int height,
2806                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2807   assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
2808   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2809   const auto sum_width = temp_stride + 8;
2810   const auto sum_stride = temp_stride + 32;
2811   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2812   const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
2813   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2814   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
2815   uint16_t *sum3[3], *ma343[3], *ma444[2];
2816   uint32_t *square_sum3[3], *b343[3], *b444[2];
2817   sum3[0] = sgr_buffer->sum3 + kSumOffset;
2818   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
2819   ma343[0] = sgr_buffer->ma343;
2820   b343[0] = sgr_buffer->b343;
2821   for (int i = 1; i <= 2; ++i) {
2822     sum3[i] = sum3[i - 1] + sum_stride;
2823     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2824     ma343[i] = ma343[i - 1] + temp_stride;
2825     b343[i] = b343[i - 1] + temp_stride;
2826   }
2827   ma444[0] = sgr_buffer->ma444;
2828   ma444[1] = ma444[0] + temp_stride;
2829   b444[0] = sgr_buffer->b444;
2830   b444[1] = b444[0] + temp_stride;
2831   assert(scale != 0);
2832   BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride,
2833             sum3[0], square_sum3[0]);
2834   BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3,
2835                                  sum_width, ma343[0], nullptr, b343[0],
2836                                  nullptr);
2837   Circulate3PointersBy1<uint16_t>(sum3);
2838   Circulate3PointersBy1<uint32_t>(square_sum3);
2839   const uint8_t* s;
2840   if (height > 1) {
2841     s = src + stride;
2842   } else {
2843     s = bottom_border;
2844     bottom_border += bottom_border_stride;
2845   }
2846   BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width,
2847                                 ma343[1], ma444[0], b343[1], b444[0]);
2848 
2849   for (int y = height - 2; y > 0; --y) {
2850     Circulate3PointersBy1<uint16_t>(sum3);
2851     Circulate3PointersBy1<uint32_t>(square_sum3);
2852     BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3,
2853                    square_sum3, ma343, ma444, b343, b444, dst);
2854     src += stride;
2855     dst += stride;
2856     Circulate3PointersBy1<uint16_t>(ma343);
2857     Circulate3PointersBy1<uint32_t>(b343);
2858     std::swap(ma444[0], ma444[1]);
2859     std::swap(b444[0], b444[1]);
2860   }
2861 
2862   int y = std::min(height, 2);
2863   src += 2;
2864   do {
2865     Circulate3PointersBy1<uint16_t>(sum3);
2866     Circulate3PointersBy1<uint32_t>(square_sum3);
2867     BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3,
2868                    square_sum3, ma343, ma444, b343, b444, dst);
2869     src += stride;
2870     dst += stride;
2871     bottom_border += bottom_border_stride;
2872     Circulate3PointersBy1<uint16_t>(ma343);
2873     Circulate3PointersBy1<uint32_t>(b343);
2874     std::swap(ma444[0], ma444[1]);
2875     std::swap(b444[0], b444[1]);
2876   } while (--y != 0);
2877 }
2878 
2879 // If |width| is not a multiple of 32, up to 31 more pixels are written to
2880 // |dest| at the end of each row. It is safe to overwrite the output as it
2881 // will not be part of the visible frame.
2882 void SelfGuidedFilter_AVX2(
2883     const RestorationUnitInfo& restoration_info, const void* const source,
2884     const ptrdiff_t stride, const void* const top_border,
2885     const ptrdiff_t top_border_stride, const void* const bottom_border,
2886     const ptrdiff_t bottom_border_stride, const int width, const int height,
2887     RestorationBuffer* const restoration_buffer, void* const dest) {
2888   const int index = restoration_info.sgr_proj_info.index;
2889   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
2890   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
2891   const auto* const src = static_cast<const uint8_t*>(source);
2892   const auto* top = static_cast<const uint8_t*>(top_border);
2893   const auto* bottom = static_cast<const uint8_t*>(bottom_border);
2894   auto* const dst = static_cast<uint8_t*>(dest);
2895   SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
2896   if (radius_pass_1 == 0) {
2897     // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
2898     // following assertion.
2899     assert(radius_pass_0 != 0);
2900     BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3,
2901                           top_border_stride, bottom - 3, bottom_border_stride,
2902                           width, height, sgr_buffer, dst);
2903   } else if (radius_pass_0 == 0) {
2904     BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2,
2905                           top_border_stride, bottom - 2, bottom_border_stride,
2906                           width, height, sgr_buffer, dst);
2907   } else {
2908     BoxFilterProcess(restoration_info, src - 3, stride, top - 3,
2909                      top_border_stride, bottom - 3, bottom_border_stride, width,
2910                      height, sgr_buffer, dst);
2911   }
2912 }
2913 
2914 void Init8bpp() {
2915   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
2916   assert(dsp != nullptr);
2917 #if DSP_ENABLED_8BPP_AVX2(WienerFilter)
2918   dsp->loop_restorations[0] = WienerFilter_AVX2;
2919 #endif
2920 #if DSP_ENABLED_8BPP_AVX2(SelfGuidedFilter)
2921   dsp->loop_restorations[1] = SelfGuidedFilter_AVX2;
2922 #endif
2923 }
2924 
2925 }  // namespace
2926 }  // namespace low_bitdepth
2927 
2928 void LoopRestorationInit_AVX2() { low_bitdepth::Init8bpp(); }
2929 
2930 }  // namespace dsp
2931 }  // namespace libgav1
2932 
2933 #else   // !LIBGAV1_TARGETING_AVX2
2934 namespace libgav1 {
2935 namespace dsp {
2936 
2937 void LoopRestorationInit_AVX2() {}
2938 
2939 }  // namespace dsp
2940 }  // namespace libgav1
2941 #endif  // LIBGAV1_TARGETING_AVX2
2942