// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.
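//
// The functions below are explicit specializations (and overloads) of the
// generic helpers declared in fixedpoint.h, so they are picked up
// automatically whenever the raw fixed-point type is a NEON vector. A minimal
// usage sketch, using only names defined in this header and in <arm_neon.h>:
//
//   int32x4_t a = vdupq_n_s32(123);
//   int32x4_t b = vdupq_n_s32(456);
//   int32x4_t sum = gemmlowp::Add(a, b);  // lowers to vaddq_s32
//   int32x4_t mul = gemmlowp::SaturatingRoundingDoublingHighMul(a, b);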

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

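// Raw-type traits for the 128-bit NEON vector types: int32x4_t carries 4
// lanes of int32, int16x8_t carries 8 lanes of int16.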
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_t> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 8;
};

template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
  return vandq_s16(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
  return vorrq_s16(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
  return veorq_s16(a, b);
}

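// Bitwise NOT, expressed as XOR with an all-ones vector.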
template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int16x8_t BitNot(int16x8_t a) {
  return veorq_s16(a, vdupq_n_s16(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int16x8_t Add(int16x8_t a, int16x8_t b) {
  return vaddq_s16(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
  return vsubq_s16(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

template <>
inline int16x8_t Neg(int16x8_t a) {
  return vnegq_s16(a);
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(offset));
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
  return vshlq_s32(a, offset);
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
  return vshlq_s16(a, offset);
}

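// NEON's register-count shift (vshl) treats negative per-lane counts as
// right shifts, so an arithmetic shift right by 'offset' is a shift left
// by -offset.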
template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

template <>
inline int16x8_t ShiftRight(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(-offset));
}

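// vbsl picks bits from then_val where the mask bit is 1 and from else_val
// where it is 0; the masks produced by the MaskIf* helpers below are all-ones
// or all-zeros per lane, so this acts as a lane-wise select.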
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
                                 int16x8_t else_val) {
  return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
}

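// NEON comparisons yield unsigned lanes of all-ones / all-zeros; reinterpret
// the result back to the signed raw type used throughout fixedpoint.h.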
template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vceqq_s16(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

template <>
inline int16x8_t MaskIfZero(int16x8_t a) {
  return MaskIfEqual(a, vdupq_n_s16(0));
}

template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int16x8_t MaskIfNonZero(int16x8_t a) {
  return vreinterpretq_s16_u16(vtstq_s16(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgtq_s16(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgeq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcltq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcleq_s16(a, b));
}

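// All() / Any() reduce a lane-wise mask to a single bool by folding the
// vector onto itself with vext + vand / vorr; after log2(lanes) steps,
// lane 0 holds the combination of all lanes.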
template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool All(int16x8_t a) {
  a = vandq_s16(a, vextq_s16(a, a, 1));
  a = vandq_s16(a, vextq_s16(a, a, 2));
  a = vandq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int16x8_t a) {
  a = vorrq_s16(a, vextq_s16(a, a, 1));
  a = vorrq_s16(a, vextq_s16(a, a, 2));
  a = vorrq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

template <>
inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
  return vrhaddq_s16(a, b);
}

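// vqrdmulh is a single-instruction saturating rounding doubling multiply
// returning the high half, which is the behavior the scalar reference in
// fixedpoint.h specifies.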
template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

template <>
inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
  return vqrdmulhq_s16(a, b);
}

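// Rounding right shift by a power of two. vrshl with a negative shift count
// rounds halfway values upward; the 'fixup' term is all-ones (-1) exactly in
// lanes where x is negative and the shift amount is nonzero, so those lanes
// are nudged down by one first, turning round-half-up into
// round-half-away-from-zero as required by the reference RoundingDivideByPOT.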
template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
  const int16x8_t shift_vec = vdupq_n_s16(-exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
  const int32x4_t shift_vec = vnegq_s32(exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
  const int16x8_t shift_vec = vnegq_s16(exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

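// Saturating multiply by a power of two. The third template argument is the
// sign of Exponent: positive exponents use a saturating left shift (vqshl);
// negative exponents use a rounding right shift with the same negative-lane
// fixup as RoundingDivideByPOT above.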
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) {
    const int32x4_t fixup = vshrq_n_s32(x, 31);
    const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
    return vrshrq_n_s32(fixed_up_x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
  static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
  static int16x8_t eval(int16x8_t x) {
    const int16x8_t fixup = vshrq_n_s16(x, 15);
    const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
    return vrshrq_n_s16(fixed_up_x, -Exponent);
  }
};

template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
  return vdupq_n_s32(x);
}

template <>
inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
  return vdupq_n_s16(x);
}

// So far this is only needed for int16.
template <>
inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
  return vqaddq_s16(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_