// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.
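//
// With these specializations, the generic fixed-point arithmetic in
// fixedpoint.h (for example FixedPoint<tRawType, tIntegerBits>) can be
// instantiated with tRawType = int32x4_t or int16x8_t, so each operation
// processes 4 or 8 lanes of a NEON register instead of a single scalar.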

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

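// Raw-type traits telling fixedpoint.h the scalar element type and the number
// of lanes of each NEON register type.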
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_t> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};

template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
  return vandq_s16(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
  return vorrq_s16(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
  return veorq_s16(a, b);
}

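// Bitwise NOT, implemented as XOR with an all-ones vector.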
template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int16x8_t BitNot(int16x8_t a) {
  return veorq_s16(a, vdupq_n_s16(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int16x8_t Add(int16x8_t a, int16x8_t b) {
  return vaddq_s16(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
  return vsubq_s16(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

template <>
inline int16x8_t Neg(int16x8_t a) {
  return vnegq_s16(a);
}

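// NEON's variable shift (vshlq) shifts left by a per-lane signed amount, so
// ShiftRight is expressed as a left shift by the negated offset.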
template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(offset));
}

template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

template <>
inline int16x8_t ShiftRight(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(-offset));
}

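// vbslq picks bits from then_val where the mask bit is set and from else_val
// otherwise. The MaskIf* functions below produce all-ones or all-zeros lanes,
// so this acts as a lane-wise select.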
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
                                 int16x8_t else_val) {
  return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
}

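// NEON comparisons return unsigned all-ones/all-zeros masks; reinterpret them
// as signed vectors to match the mask convention used by fixedpoint.h.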
template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vceqq_s16(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

template <>
inline int16x8_t MaskIfZero(int16x8_t a) {
  return MaskIfEqual(a, vdupq_n_s16(0));
}

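// vtstq sets a lane to all-ones when (a & a) is nonzero, i.e. when the lane
// itself is nonzero.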
template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int16x8_t MaskIfNonZero(int16x8_t a) {
  return vreinterpretq_s16_u16(vtstq_s16(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgtq_s16(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgeq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcltq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcleq_s16(a, b));
}

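// Horizontal reductions of a lane-wise mask: repeatedly rotate the vector with
// vextq and combine it with itself (AND for All, OR for Any) so that after
// log2(lanes) steps lane 0 holds the reduction of all lanes.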
template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool All(int16x8_t a) {
  a = vandq_s16(a, vextq_s16(a, a, 1));
  a = vandq_s16(a, vextq_s16(a, a, 2));
  a = vandq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int16x8_t a) {
  a = vorrq_s16(a, vextq_s16(a, a, 1));
  a = vorrq_s16(a, vextq_s16(a, a, 2));
  a = vorrq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

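// vrhaddq computes (a + b + 1) >> 1 using a widened intermediate, so the sum
// cannot overflow.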
template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

template <>
inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
  return vrhaddq_s16(a, b);
}

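// vqrdmulhq returns the high half of the doubled product (2 * a * b) with
// rounding, saturating the one overflow case where both inputs equal the
// minimum representable value.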
template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

template <>
inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
  return vqrdmulhq_s16(a, b);
}

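// vrshlq with a negative shift count performs a rounding arithmetic right
// shift (round half up). The fixup term subtracts 1 from negative inputs
// beforehand, turning that into the round-half-away-from-zero behavior of the
// scalar RoundingDivideByPOT in fixedpoint.h.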
template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
  const int16x8_t shift_vec = vdupq_n_s16(-exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

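// The third template argument is the sign of Exponent: for positive exponents
// this is a saturating left shift (vqshlq_n); for negative exponents it is a
// rounding right shift preceded by the same negative-input fixup as
// RoundingDivideByPOT above.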
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) {
    const int32x4_t fixup = vshrq_n_s32(x, 31);
    const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
    return vrshrq_n_s32(fixed_up_x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
  static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
  static int16x8_t eval(int16x8_t x) {
    const int16x8_t fixup = vshrq_n_s16(x, 15);
    const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
    return vrshrq_n_s16(fixed_up_x, -Exponent);
  }
};

template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
  return vdupq_n_s32(x);
}

template <>
inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
  return vdupq_n_s16(x);
}

// So far this is only needed for int16.
template <>
inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
  return vqaddq_s16(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_