/*
 * Copyright 2019 The libgav1 Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
#define LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_

#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <cstdint>
#include <cstring>

#if 0
#include <cstdio>
#include <string>

constexpr bool kEnablePrintRegs = true;

union DebugRegister {
  int8_t i8[8];
  int16_t i16[4];
  int32_t i32[2];
  uint8_t u8[8];
  uint16_t u16[4];
  uint32_t u32[2];
};

union DebugRegisterQ {
  int8_t i8[16];
  int16_t i16[8];
  int32_t i32[4];
  uint8_t u8[16];
  uint16_t u16[8];
  uint32_t u32[4];
};

// Quite useful macro for debugging. Left here for convenience.
inline void PrintVect(const DebugRegister r, const char* const name, int size) {
  int n;
  if (kEnablePrintRegs) {
    fprintf(stderr, "%s\t: ", name);
    if (size == 8) {
      for (n = 0; n < 8; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
    } else if (size == 16) {
      for (n = 0; n < 4; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
    } else if (size == 32) {
      for (n = 0; n < 2; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
    }
    fprintf(stderr, "\n");
  }
}

// Debugging macro for 128-bit types.
inline void PrintVectQ(const DebugRegisterQ r, const char* const name,
                       int size) {
  int n;
  if (kEnablePrintRegs) {
    fprintf(stderr, "%s\t: ", name);
    if (size == 8) {
      for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
    } else if (size == 16) {
      for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
    } else if (size == 32) {
      for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
    }
    fprintf(stderr, "\n");
  }
}

inline void PrintReg(const int32x4x2_t val, const std::string& name) {
  DebugRegisterQ r;
  vst1q_s32(r.i32, val.val[0]);
  const std::string name0 = name + std::string(".val[0]");
  PrintVectQ(r, name0.c_str(), 32);
  vst1q_s32(r.i32, val.val[1]);
  const std::string name1 = name + std::string(".val[1]");
  PrintVectQ(r, name1.c_str(), 32);
}

inline void PrintReg(const uint32x4_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_u32(r.u32, val);
  PrintVectQ(r, name, 32);
}

inline void PrintReg(const uint32x2_t val, const char* name) {
  DebugRegister r;
  vst1_u32(r.u32, val);
  PrintVect(r, name, 32);
}

inline void PrintReg(const uint16x8_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_u16(r.u16, val);
  PrintVectQ(r, name, 16);
}

inline void PrintReg(const uint16x4_t val, const char* name) {
  DebugRegister r;
  vst1_u16(r.u16, val);
  PrintVect(r, name, 16);
}

inline void PrintReg(const uint8x16_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_u8(r.u8, val);
  PrintVectQ(r, name, 8);
}

inline void PrintReg(const uint8x8_t val, const char* name) {
  DebugRegister r;
  vst1_u8(r.u8, val);
  PrintVect(r, name, 8);
}

inline void PrintReg(const int32x4_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_s32(r.i32, val);
  PrintVectQ(r, name, 32);
}

inline void PrintReg(const int32x2_t val, const char* name) {
  DebugRegister r;
  vst1_s32(r.i32, val);
  PrintVect(r, name, 32);
}

inline void PrintReg(const int16x8_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_s16(r.i16, val);
  PrintVectQ(r, name, 16);
}

inline void PrintReg(const int16x4_t val, const char* name) {
  DebugRegister r;
  vst1_s16(r.i16, val);
  PrintVect(r, name, 16);
}

inline void PrintReg(const int8x16_t val, const char* name) {
  DebugRegisterQ r;
  vst1q_s8(r.i8, val);
  PrintVectQ(r, name, 8);
}

inline void PrintReg(const int8x8_t val, const char* name) {
  DebugRegister r;
  vst1_s8(r.i8, val);
  PrintVect(r, name, 8);
}

// Print an individual (non-vector) value in decimal format.
inline void PrintReg(const int x, const char* name) {
  if (kEnablePrintRegs) {
    fprintf(stderr, "%s: %d\n", name, x);
  }
}

// Print an individual (non-vector) value in hexadecimal format.
inline void PrintHex(const int x, const char* name) {
  if (kEnablePrintRegs) {
    fprintf(stderr, "%s: %x\n", name, x);
  }
}

#define PR(x) PrintReg(x, #x)
#define PD(x) PrintReg(x, #x)
#define PX(x) PrintHex(x, #x)

#endif  // 0

namespace libgav1 {
namespace dsp {

//------------------------------------------------------------------------------
// Load functions.

// Load 2 uint8_t values into lanes 0 and 1. Zeros the register before loading
// the values. Use caution when using this in loops because it will re-zero the
// register before loading on every iteration.
inline uint8x8_t Load2(const void* const buf) {
  const uint16x4_t zero = vdup_n_u16(0);
  uint16_t temp;
  memcpy(&temp, buf, 2);
  return vreinterpret_u8_u16(vld1_lane_u16(&temp, zero, 0));
}

// Load 2 uint8_t values into lanes |lane| * 2 and |lane| * 2 + 1.
template <int lane>
inline uint8x8_t Load2(const void* const buf, uint8x8_t val) {
  uint16_t temp;
  memcpy(&temp, buf, 2);
  return vreinterpret_u8_u16(
      vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
}

// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
// register before loading the values. Use caution when using this in loops
// because it will re-zero the register before loading on every iteration.
inline uint8x8_t Load4(const void* const buf) {
  const uint32x2_t zero = vdup_n_u32(0);
  uint32_t temp;
  memcpy(&temp, buf, 4);
  return vreinterpret_u8_u32(vld1_lane_u32(&temp, zero, 0));
}

// Load 4 uint8_t values into 4 lanes starting with |lane| * 4.
template <int lane>
inline uint8x8_t Load4(const void* const buf, uint8x8_t val) {
  uint32_t temp;
  memcpy(&temp, buf, 4);
  return vreinterpret_u8_u32(
      vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
}
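
// A minimal usage sketch (not part of the upstream API): gather two 4-byte
// rows into a single uint8x8_t with the Load4() helpers above. |row0| and
// |row1| are hypothetical caller-provided pointers.
inline uint8x8_t LoadTwoRowsOf4_Example(const uint8_t* row0,
                                        const uint8_t* row1) {
  uint8x8_t v = Load4(row0);  // Lanes 0-3 come from |row0|; the rest are zero.
  return Load4<1>(row1, v);   // Lanes 4-7 come from |row1|.
}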

//------------------------------------------------------------------------------
// Store functions.

// Propagate type information to the compiler. Without this the compiler may
// assume the required alignment of the type (4 bytes in the case of uint32_t)
// and add alignment hints to the memory access.
template <typename T>
inline void ValueToMem(void* const buf, T val) {
  memcpy(buf, &val, sizeof(val));
}

// Store 4 int8_t values from the low half of an int8x8_t register.
inline void StoreLo4(void* const buf, const int8x8_t val) {
  ValueToMem<int32_t>(buf, vget_lane_s32(vreinterpret_s32_s8(val), 0));
}

// Store 4 uint8_t values from the low half of a uint8x8_t register.
inline void StoreLo4(void* const buf, const uint8x8_t val) {
  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0));
}

// Store 4 uint8_t values from the high half of a uint8x8_t register.
inline void StoreHi4(void* const buf, const uint8x8_t val) {
  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1));
}

// Store 2 uint8_t values from |lane| * 2 and |lane| * 2 + 1 of a uint8x8_t
// register.
template <int lane>
inline void Store2(void* const buf, const uint8x8_t val) {
  ValueToMem<uint16_t>(buf, vget_lane_u16(vreinterpret_u16_u8(val), lane));
}

// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x8_t
// register.
template <int lane>
inline void Store2(void* const buf, const uint16x8_t val) {
  ValueToMem<uint32_t>(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane));
}

// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
// register.
template <int lane>
inline void Store2(uint16_t* const buf, const uint16x4_t val) {
  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
}

// Simplify code when caller has |buf| cast as uint8_t*.
inline void Store4(void* const buf, const uint16x4_t val) {
  vst1_u16(static_cast<uint16_t*>(buf), val);
}

// Simplify code when caller has |buf| cast as uint8_t*.
inline void Store8(void* const buf, const uint16x8_t val) {
  vst1q_u16(static_cast<uint16_t*>(buf), val);
}
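
// A minimal usage sketch (not part of the upstream API): scatter the register
// produced by LoadTwoRowsOf4_Example() back to two 4-byte rows. |row0| and
// |row1| are hypothetical destination pointers.
inline void StoreTwoRowsOf4_Example(uint8_t* row0, uint8_t* row1,
                                    const uint8x8_t val) {
  StoreLo4(row0, val);  // Lanes 0-3.
  StoreHi4(row1, val);  // Lanes 4-7.
}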

//------------------------------------------------------------------------------
// Bit manipulation.

// vshXX_n_XX() requires an immediate.
template <int shift>
inline uint8x8_t LeftShiftVector(const uint8x8_t vector) {
  return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift));
}

template <int shift>
inline uint8x8_t RightShiftVector(const uint8x8_t vector) {
  return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift));
}

template <int shift>
inline int8x8_t RightShiftVector(const int8x8_t vector) {
  return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift));
}
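
// A minimal usage sketch (not part of the upstream API): the shift count for
// the helpers above applies to the whole 64-bit register and is measured in
// bits, so moving whole byte lanes means shifting by a multiple of 8. Here a
// right shift by 16 drops lanes 0-1, moves lanes 2-7 down to lanes 0-5 and
// zero-fills lanes 6-7.
inline uint8x8_t DropLowTwoBytes_Example(const uint8x8_t vector) {
  return RightShiftVector<16>(vector);
}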

// Shim vqtbl1_u8 for armv7.
inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) {
#if defined(__aarch64__)
  return vqtbl1_u8(a, index);
#else
  const uint8x8x2_t b = {vget_low_u8(a), vget_high_u8(a)};
  return vtbl2_u8(b, index);
#endif
}

// Shim vqtbl1_s8 for armv7.
inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
#if defined(__aarch64__)
  return vqtbl1_s8(a, index);
#else
  const int8x8x2_t b = {vget_low_s8(a), vget_high_s8(a)};
  return vtbl2_s8(b, vreinterpret_s8_u8(index));
#endif
}
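
// A minimal usage sketch (not part of the upstream API): use the table-lookup
// shim to pull an arbitrary byte permutation out of a 128-bit register. This
// selects the high half of |a| in reverse order; |kReverseHigh| is a
// hypothetical index table.
inline uint8x8_t ReverseHighHalf_Example(const uint8x16_t a) {
  const uint8_t kReverseHigh[8] = {15, 14, 13, 12, 11, 10, 9, 8};
  return VQTbl1U8(a, vld1_u8(kReverseHigh));
}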

//------------------------------------------------------------------------------
// Interleave.

// vzipN is exclusive to A64.
inline uint8x8_t InterleaveLow8(const uint8x8_t a, const uint8x8_t b) {
#if defined(__aarch64__)
  return vzip1_u8(a, b);
#else
  // Discard |.val[1]|
  return vzip_u8(a, b).val[0];
#endif
}

inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) {
#if defined(__aarch64__)
  return vreinterpret_u8_u32(
      vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
#else
  // Discard |.val[1]|
  return vreinterpret_u8_u32(
      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]);
#endif
}

inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) {
#if defined(__aarch64__)
  return vreinterpret_s8_u32(
      vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
#else
  // Discard |.val[1]|
  return vreinterpret_s8_u32(
      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]);
#endif
}

inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) {
#if defined(__aarch64__)
  return vreinterpret_u8_u32(
      vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
#else
  // Discard |.val[0]|
  return vreinterpret_u8_u32(
      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]);
#endif
}

inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) {
#if defined(__aarch64__)
  return vreinterpret_s8_u32(
      vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
#else
  // Discard |.val[0]|
  return vreinterpret_s8_u32(
      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]);
#endif
}

//------------------------------------------------------------------------------
// Sum.

inline uint16_t SumVector(const uint8x8_t a) {
#if defined(__aarch64__)
  return vaddlv_u8(a);
#else
  const uint16x4_t c = vpaddl_u8(a);
  const uint32x2_t d = vpaddl_u16(c);
  const uint64x1_t e = vpaddl_u32(d);
  return static_cast<uint16_t>(vget_lane_u64(e, 0));
#endif  // defined(__aarch64__)
}

inline uint32_t SumVector(const uint32x2_t a) {
#if defined(__aarch64__)
  return vaddv_u32(a);
#else
  const uint64x1_t b = vpaddl_u32(a);
  return vget_lane_u32(vreinterpret_u32_u64(b), 0);
#endif  // defined(__aarch64__)
}

inline uint32_t SumVector(const uint32x4_t a) {
#if defined(__aarch64__)
  return vaddvq_u32(a);
#else
  const uint64x2_t b = vpaddlq_u32(a);
  const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
  return static_cast<uint32_t>(vget_lane_u64(c, 0));
#endif
}
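
// A minimal usage sketch (not part of the upstream API): the uint8x8_t
// overload of SumVector() is a convenient final reduction for a small sum of
// absolute differences. |a| and |b| are hypothetical 8-pixel rows; the result
// fits in uint16_t because 8 * 255 = 2040.
inline uint16_t Sad8_Example(const uint8x8_t a, const uint8x8_t b) {
  return SumVector(vabd_u8(a, b));
}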

//------------------------------------------------------------------------------
// Transpose.

// Transpose 32 bit elements such that:
// a: 00 01
// b: 02 03
// returns
// val[0]: 00 02
// val[1]: 01 03
inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) {
  const uint32x2_t a_32 = vreinterpret_u32_u8(a);
  const uint32x2_t b_32 = vreinterpret_u32_u8(b);
  const uint32x2x2_t c = vtrn_u32(a_32, b_32);
  const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]),
                         vreinterpret_u8_u32(c.val[1])};
  return d;
}

// Swap high and low 32 bit elements.
inline uint8x8_t Transpose32(const uint8x8_t a) {
  const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a));
  return vreinterpret_u8_u32(b);
}

// Implement vtrnq_s64().
// Input:
// a0: 00 01 02 03 04 05 06 07
// a1: 16 17 18 19 20 21 22 23
// Output:
// b0.val[0]: 00 01 02 03 16 17 18 19
// b0.val[1]: 04 05 06 07 20 21 22 23
inline int16x8x2_t VtrnqS64(int32x4_t a0, int32x4_t a1) {
  int16x8x2_t b0;
  b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
                           vreinterpret_s16_s32(vget_low_s32(a1)));
  b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
                           vreinterpret_s16_s32(vget_high_s32(a1)));
  return b0;
}

inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) {
  uint16x8x2_t b0;
  b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
                           vreinterpret_u16_u32(vget_low_u32(a1)));
  b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
                           vreinterpret_u16_u32(vget_high_u32(a1)));
  return b0;
}

// Input:
// 00 01 02 03
// 10 11 12 13
// 20 21 22 23
// 30 31 32 33
inline void Transpose4x4(uint16x4_t a[4]) {
  // b:
  // 00 10 02 12
  // 01 11 03 13
  const uint16x4x2_t b = vtrn_u16(a[0], a[1]);
  // c:
  // 20 30 22 32
  // 21 31 23 33
  const uint16x4x2_t c = vtrn_u16(a[2], a[3]);
  // d:
  // 00 10 20 30
  // 02 12 22 32
  const uint32x2x2_t d =
      vtrn_u32(vreinterpret_u32_u16(b.val[0]), vreinterpret_u32_u16(c.val[0]));
  // e:
  // 01 11 21 31
  // 03 13 23 33
  const uint32x2x2_t e =
      vtrn_u32(vreinterpret_u32_u16(b.val[1]), vreinterpret_u32_u16(c.val[1]));
  a[0] = vreinterpret_u16_u32(d.val[0]);
  a[1] = vreinterpret_u16_u32(e.val[0]);
  a[2] = vreinterpret_u16_u32(d.val[1]);
  a[3] = vreinterpret_u16_u32(e.val[1]);
}
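
// A minimal usage sketch (not part of the upstream API): load a 4x4 block of
// uint16_t pixels, transpose it in registers and store it back. |src|, |dst|
// and the strides (in elements) are hypothetical caller-provided values.
inline void Transpose4x4Block_Example(const uint16_t* src, int src_stride,
                                      uint16_t* dst, int dst_stride) {
  uint16x4_t rows[4];
  for (int i = 0; i < 4; ++i) rows[i] = vld1_u16(src + i * src_stride);
  Transpose4x4(rows);
  for (int i = 0; i < 4; ++i) vst1_u16(dst + i * dst_stride, rows[i]);
}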

// Input:
// a: 00 01 02 03 10 11 12 13
// b: 20 21 22 23 30 31 32 33
// Output:
// Note that columns [1] and [2] are transposed.
// a: 00 10 20 30 02 12 22 32
// b: 01 11 21 31 03 13 23 33
inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) {
  const uint16x4x2_t c =
      vtrn_u16(vreinterpret_u16_u8(*a), vreinterpret_u16_u8(*b));
  const uint32x2x2_t d =
      vtrn_u32(vreinterpret_u32_u16(c.val[0]), vreinterpret_u32_u16(c.val[1]));
  const uint8x8x2_t e =
      vtrn_u8(vreinterpret_u8_u32(d.val[0]), vreinterpret_u8_u32(d.val[1]));
  *a = e.val[0];
  *b = e.val[1];
}

// Reversible if the x4 values are packed next to each other.
// x4 input / x8 output:
// a0: 00 01 02 03 40 41 42 43
// a1: 10 11 12 13 50 51 52 53
// a2: 20 21 22 23 60 61 62 63
// a3: 30 31 32 33 70 71 72 73
// x8 input / x4 output:
// a0: 00 10 20 30 40 50 60 70
// a1: 01 11 21 31 41 51 61 71
// a2: 02 12 22 32 42 52 62 72
// a3: 03 13 23 33 43 53 63 73
inline void Transpose8x4(uint8x8_t* a0, uint8x8_t* a1, uint8x8_t* a2,
                         uint8x8_t* a3) {
  const uint8x8x2_t b0 = vtrn_u8(*a0, *a1);
  const uint8x8x2_t b1 = vtrn_u8(*a2, *a3);

  const uint16x4x2_t c0 =
      vtrn_u16(vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0]));
  const uint16x4x2_t c1 =
      vtrn_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1]));

  *a0 = vreinterpret_u8_u16(c0.val[0]);
  *a1 = vreinterpret_u8_u16(c1.val[0]);
  *a2 = vreinterpret_u8_u16(c0.val[1]);
  *a3 = vreinterpret_u8_u16(c1.val[1]);
}

// Input:
// a[0]: 00 01 02 03 04 05 06 07
// a[1]: 10 11 12 13 14 15 16 17
// a[2]: 20 21 22 23 24 25 26 27
// a[3]: 30 31 32 33 34 35 36 37
// a[4]: 40 41 42 43 44 45 46 47
// a[5]: 50 51 52 53 54 55 56 57
// a[6]: 60 61 62 63 64 65 66 67
// a[7]: 70 71 72 73 74 75 76 77

// Output:
// a[0]: 00 10 20 30 40 50 60 70
// a[1]: 01 11 21 31 41 51 61 71
// a[2]: 02 12 22 32 42 52 62 72
// a[3]: 03 13 23 33 43 53 63 73
// a[4]: 04 14 24 34 44 54 64 74
// a[5]: 05 15 25 35 45 55 65 75
// a[6]: 06 16 26 36 46 56 66 76
// a[7]: 07 17 27 37 47 57 67 77
inline void Transpose8x8(int8x8_t a[8]) {
  // Swap 8 bit elements. Goes from:
  // a[0]: 00 01 02 03 04 05 06 07
  // a[1]: 10 11 12 13 14 15 16 17
  // a[2]: 20 21 22 23 24 25 26 27
  // a[3]: 30 31 32 33 34 35 36 37
  // a[4]: 40 41 42 43 44 45 46 47
  // a[5]: 50 51 52 53 54 55 56 57
  // a[6]: 60 61 62 63 64 65 66 67
  // a[7]: 70 71 72 73 74 75 76 77
  // to:
  // b0.val[0]: 00 10 02 12 04 14 06 16  40 50 42 52 44 54 46 56
  // b0.val[1]: 01 11 03 13 05 15 07 17  41 51 43 53 45 55 47 57
  // b1.val[0]: 20 30 22 32 24 34 26 36  60 70 62 72 64 74 66 76
  // b1.val[1]: 21 31 23 33 25 35 27 37  61 71 63 73 65 75 67 77
  const int8x16x2_t b0 =
      vtrnq_s8(vcombine_s8(a[0], a[4]), vcombine_s8(a[1], a[5]));
  const int8x16x2_t b1 =
      vtrnq_s8(vcombine_s8(a[2], a[6]), vcombine_s8(a[3], a[7]));

  // Swap 16 bit elements resulting in:
  // c0.val[0]: 00 10 20 30 04 14 24 34  40 50 60 70 44 54 64 74
  // c0.val[1]: 02 12 22 32 06 16 26 36  42 52 62 72 46 56 66 76
  // c1.val[0]: 01 11 21 31 05 15 25 35  41 51 61 71 45 55 65 75
  // c1.val[1]: 03 13 23 33 07 17 27 37  43 53 63 73 47 57 67 77
  const int16x8x2_t c0 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[0]),
                                   vreinterpretq_s16_s8(b1.val[0]));
  const int16x8x2_t c1 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[1]),
                                   vreinterpretq_s16_s8(b1.val[1]));

  // Unzip 32 bit elements resulting in:
  // d0.val[0]: 00 10 20 30 40 50 60 70  01 11 21 31 41 51 61 71
  // d0.val[1]: 04 14 24 34 44 54 64 74  05 15 25 35 45 55 65 75
  // d1.val[0]: 02 12 22 32 42 52 62 72  03 13 23 33 43 53 63 73
  // d1.val[1]: 06 16 26 36 46 56 66 76  07 17 27 37 47 57 67 77
  const int32x4x2_t d0 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[0]),
                                   vreinterpretq_s32_s16(c1.val[0]));
  const int32x4x2_t d1 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[1]),
                                   vreinterpretq_s32_s16(c1.val[1]));

  a[0] = vreinterpret_s8_s32(vget_low_s32(d0.val[0]));
  a[1] = vreinterpret_s8_s32(vget_high_s32(d0.val[0]));
  a[2] = vreinterpret_s8_s32(vget_low_s32(d1.val[0]));
  a[3] = vreinterpret_s8_s32(vget_high_s32(d1.val[0]));
  a[4] = vreinterpret_s8_s32(vget_low_s32(d0.val[1]));
  a[5] = vreinterpret_s8_s32(vget_high_s32(d0.val[1]));
  a[6] = vreinterpret_s8_s32(vget_low_s32(d1.val[1]));
  a[7] = vreinterpret_s8_s32(vget_high_s32(d1.val[1]));
}

// Unsigned.
inline void Transpose8x8(uint8x8_t a[8]) {
  const uint8x16x2_t b0 =
      vtrnq_u8(vcombine_u8(a[0], a[4]), vcombine_u8(a[1], a[5]));
  const uint8x16x2_t b1 =
      vtrnq_u8(vcombine_u8(a[2], a[6]), vcombine_u8(a[3], a[7]));

  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
                                    vreinterpretq_u16_u8(b1.val[0]));
  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
                                    vreinterpretq_u16_u8(b1.val[1]));

  const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
                                    vreinterpretq_u32_u16(c1.val[0]));
  const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
                                    vreinterpretq_u32_u16(c1.val[1]));

  a[0] = vreinterpret_u8_u32(vget_low_u32(d0.val[0]));
  a[1] = vreinterpret_u8_u32(vget_high_u32(d0.val[0]));
  a[2] = vreinterpret_u8_u32(vget_low_u32(d1.val[0]));
  a[3] = vreinterpret_u8_u32(vget_high_u32(d1.val[0]));
  a[4] = vreinterpret_u8_u32(vget_low_u32(d0.val[1]));
  a[5] = vreinterpret_u8_u32(vget_high_u32(d0.val[1]));
  a[6] = vreinterpret_u8_u32(vget_low_u32(d1.val[1]));
  a[7] = vreinterpret_u8_u32(vget_high_u32(d1.val[1]));
}

inline void Transpose8x8(uint8x8_t in[8], uint8x16_t out[4]) {
  const uint8x16x2_t a0 =
      vtrnq_u8(vcombine_u8(in[0], in[4]), vcombine_u8(in[1], in[5]));
  const uint8x16x2_t a1 =
      vtrnq_u8(vcombine_u8(in[2], in[6]), vcombine_u8(in[3], in[7]));

  const uint16x8x2_t b0 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[0]),
                                    vreinterpretq_u16_u8(a1.val[0]));
  const uint16x8x2_t b1 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[1]),
                                    vreinterpretq_u16_u8(a1.val[1]));

  const uint32x4x2_t c0 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[0]),
                                    vreinterpretq_u32_u16(b1.val[0]));
  const uint32x4x2_t c1 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[1]),
                                    vreinterpretq_u32_u16(b1.val[1]));

  out[0] = vreinterpretq_u8_u32(c0.val[0]);
  out[1] = vreinterpretq_u8_u32(c1.val[0]);
  out[2] = vreinterpretq_u8_u32(c0.val[1]);
  out[3] = vreinterpretq_u8_u32(c1.val[1]);
}
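
// A minimal usage sketch (not part of the upstream API): transpose an 8x8
// block of bytes straight into four q-registers, e.g. ahead of column
// filtering. |src| and |stride| are hypothetical caller-provided values.
// After the call, out[0] holds input columns 0-1 in its low and high halves,
// out[1] columns 2-3, out[2] columns 4-5 and out[3] columns 6-7.
inline void Transpose8x8Block_Example(const uint8_t* src, int stride,
                                      uint8x16_t out[4]) {
  uint8x8_t rows[8];
  for (int i = 0; i < 8; ++i) rows[i] = vld1_u8(src + i * stride);
  Transpose8x8(rows, out);
}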

// Input:
// a[0]: 00 01 02 03 04 05 06 07
// a[1]: 10 11 12 13 14 15 16 17
// a[2]: 20 21 22 23 24 25 26 27
// a[3]: 30 31 32 33 34 35 36 37
// a[4]: 40 41 42 43 44 45 46 47
// a[5]: 50 51 52 53 54 55 56 57
// a[6]: 60 61 62 63 64 65 66 67
// a[7]: 70 71 72 73 74 75 76 77

// Output:
// a[0]: 00 10 20 30 40 50 60 70
// a[1]: 01 11 21 31 41 51 61 71
// a[2]: 02 12 22 32 42 52 62 72
// a[3]: 03 13 23 33 43 53 63 73
// a[4]: 04 14 24 34 44 54 64 74
// a[5]: 05 15 25 35 45 55 65 75
// a[6]: 06 16 26 36 46 56 66 76
// a[7]: 07 17 27 37 47 57 67 77
inline void Transpose8x8(int16x8_t a[8]) {
  const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]);
  const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]);
  const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]);
  const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]);

  const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
                                   vreinterpretq_s32_s16(b1.val[0]));
  const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
                                   vreinterpretq_s32_s16(b1.val[1]));
  const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
                                   vreinterpretq_s32_s16(b3.val[0]));
  const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
                                   vreinterpretq_s32_s16(b3.val[1]));

  const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
  const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
  const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
  const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);

  a[0] = d0.val[0];
  a[1] = d1.val[0];
  a[2] = d2.val[0];
  a[3] = d3.val[0];
  a[4] = d0.val[1];
  a[5] = d1.val[1];
  a[6] = d2.val[1];
  a[7] = d3.val[1];
}

// Unsigned.
inline void Transpose8x8(uint16x8_t a[8]) {
  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
  const uint16x8x2_t b2 = vtrnq_u16(a[4], a[5]);
  const uint16x8x2_t b3 = vtrnq_u16(a[6], a[7]);

  const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
                                    vreinterpretq_u32_u16(b1.val[0]));
  const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
                                    vreinterpretq_u32_u16(b1.val[1]));
  const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]),
                                    vreinterpretq_u32_u16(b3.val[0]));
  const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
                                    vreinterpretq_u32_u16(b3.val[1]));

  const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c2.val[0]);
  const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c3.val[0]);
  const uint16x8x2_t d2 = VtrnqU64(c0.val[1], c2.val[1]);
  const uint16x8x2_t d3 = VtrnqU64(c1.val[1], c3.val[1]);

  a[0] = d0.val[0];
  a[1] = d1.val[0];
  a[2] = d2.val[0];
  a[3] = d3.val[0];
  a[4] = d0.val[1];
  a[5] = d1.val[1];
  a[6] = d2.val[1];
  a[7] = d3.val[1];
}

// Input:
// a[0]: 00 01 02 03 04 05 06 07  80 81 82 83 84 85 86 87
// a[1]: 10 11 12 13 14 15 16 17  90 91 92 93 94 95 96 97
// a[2]: 20 21 22 23 24 25 26 27  a0 a1 a2 a3 a4 a5 a6 a7
// a[3]: 30 31 32 33 34 35 36 37  b0 b1 b2 b3 b4 b5 b6 b7
// a[4]: 40 41 42 43 44 45 46 47  c0 c1 c2 c3 c4 c5 c6 c7
// a[5]: 50 51 52 53 54 55 56 57  d0 d1 d2 d3 d4 d5 d6 d7
// a[6]: 60 61 62 63 64 65 66 67  e0 e1 e2 e3 e4 e5 e6 e7
// a[7]: 70 71 72 73 74 75 76 77  f0 f1 f2 f3 f4 f5 f6 f7

// Output:
// a[0]: 00 10 20 30 40 50 60 70  80 90 a0 b0 c0 d0 e0 f0
// a[1]: 01 11 21 31 41 51 61 71  81 91 a1 b1 c1 d1 e1 f1
// a[2]: 02 12 22 32 42 52 62 72  82 92 a2 b2 c2 d2 e2 f2
// a[3]: 03 13 23 33 43 53 63 73  83 93 a3 b3 c3 d3 e3 f3
// a[4]: 04 14 24 34 44 54 64 74  84 94 a4 b4 c4 d4 e4 f4
// a[5]: 05 15 25 35 45 55 65 75  85 95 a5 b5 c5 d5 e5 f5
// a[6]: 06 16 26 36 46 56 66 76  86 96 a6 b6 c6 d6 e6 f6
// a[7]: 07 17 27 37 47 57 67 77  87 97 a7 b7 c7 d7 e7 f7
inline void Transpose8x16(uint8x16_t a[8]) {
  // b0.val[0]: 00 10 02 12 04 14 06 16  80 90 82 92 84 94 86 96
  // b0.val[1]: 01 11 03 13 05 15 07 17  81 91 83 93 85 95 87 97
  // b1.val[0]: 20 30 22 32 24 34 26 36  a0 b0 a2 b2 a4 b4 a6 b6
  // b1.val[1]: 21 31 23 33 25 35 27 37  a1 b1 a3 b3 a5 b5 a7 b7
  // b2.val[0]: 40 50 42 52 44 54 46 56  c0 d0 c2 d2 c4 d4 c6 d6
  // b2.val[1]: 41 51 43 53 45 55 47 57  c1 d1 c3 d3 c5 d5 c7 d7
  // b3.val[0]: 60 70 62 72 64 74 66 76  e0 f0 e2 f2 e4 f4 e6 f6
  // b3.val[1]: 61 71 63 73 65 75 67 77  e1 f1 e3 f3 e5 f5 e7 f7
  const uint8x16x2_t b0 = vtrnq_u8(a[0], a[1]);
  const uint8x16x2_t b1 = vtrnq_u8(a[2], a[3]);
  const uint8x16x2_t b2 = vtrnq_u8(a[4], a[5]);
  const uint8x16x2_t b3 = vtrnq_u8(a[6], a[7]);

  // c0.val[0]: 00 10 20 30 04 14 24 34  80 90 a0 b0 84 94 a4 b4
  // c0.val[1]: 02 12 22 32 06 16 26 36  82 92 a2 b2 86 96 a6 b6
  // c1.val[0]: 01 11 21 31 05 15 25 35  81 91 a1 b1 85 95 a5 b5
  // c1.val[1]: 03 13 23 33 07 17 27 37  83 93 a3 b3 87 97 a7 b7
  // c2.val[0]: 40 50 60 70 44 54 64 74  c0 d0 e0 f0 c4 d4 e4 f4
  // c2.val[1]: 42 52 62 72 46 56 66 76  c2 d2 e2 f2 c6 d6 e6 f6
  // c3.val[0]: 41 51 61 71 45 55 65 75  c1 d1 e1 f1 c5 d5 e5 f5
  // c3.val[1]: 43 53 63 73 47 57 67 77  c3 d3 e3 f3 c7 d7 e7 f7
  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
                                    vreinterpretq_u16_u8(b1.val[0]));
  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
                                    vreinterpretq_u16_u8(b1.val[1]));
  const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]),
                                    vreinterpretq_u16_u8(b3.val[0]));
  const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]),
                                    vreinterpretq_u16_u8(b3.val[1]));

  // d0.val[0]: 00 10 20 30 40 50 60 70  80 90 a0 b0 c0 d0 e0 f0
  // d0.val[1]: 04 14 24 34 44 54 64 74  84 94 a4 b4 c4 d4 e4 f4
  // d1.val[0]: 01 11 21 31 41 51 61 71  81 91 a1 b1 c1 d1 e1 f1
  // d1.val[1]: 05 15 25 35 45 55 65 75  85 95 a5 b5 c5 d5 e5 f5
  // d2.val[0]: 02 12 22 32 42 52 62 72  82 92 a2 b2 c2 d2 e2 f2
  // d2.val[1]: 06 16 26 36 46 56 66 76  86 96 a6 b6 c6 d6 e6 f6
  // d3.val[0]: 03 13 23 33 43 53 63 73  83 93 a3 b3 c3 d3 e3 f3
  // d3.val[1]: 07 17 27 37 47 57 67 77  87 97 a7 b7 c7 d7 e7 f7
  const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
                                    vreinterpretq_u32_u16(c2.val[0]));
  const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]),
                                    vreinterpretq_u32_u16(c3.val[0]));
  const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
                                    vreinterpretq_u32_u16(c2.val[1]));
  const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]),
                                    vreinterpretq_u32_u16(c3.val[1]));

  a[0] = vreinterpretq_u8_u32(d0.val[0]);
  a[1] = vreinterpretq_u8_u32(d1.val[0]);
  a[2] = vreinterpretq_u8_u32(d2.val[0]);
  a[3] = vreinterpretq_u8_u32(d3.val[0]);
  a[4] = vreinterpretq_u8_u32(d0.val[1]);
  a[5] = vreinterpretq_u8_u32(d1.val[1]);
  a[6] = vreinterpretq_u8_u32(d2.val[1]);
  a[7] = vreinterpretq_u8_u32(d3.val[1]);
}

inline int16x8_t ZeroExtend(const uint8x8_t in) {
  return vreinterpretq_s16_u16(vmovl_u8(in));
}
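
// A minimal usage sketch (not part of the upstream API): widen an 8-pixel row
// to int16 with ZeroExtend() so it can be combined with signed coefficients.
// |pixels| and |taps| are hypothetical inputs.
inline int16x8_t ApplyTaps_Example(const uint8x8_t pixels,
                                   const int16x8_t taps) {
  return vmulq_s16(ZeroExtend(pixels), taps);
}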

}  // namespace dsp
}  // namespace libgav1

#endif  // LIBGAV1_ENABLE_NEON
#endif  // LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_