// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_t> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 8;
};

template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
  return vandq_s16(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
  return vorrq_s16(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
  return veorq_s16(a, b);
}

template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int16x8_t BitNot(int16x8_t a) {
  return veorq_s16(a, vdupq_n_s16(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int16x8_t Add(int16x8_t a, int16x8_t b) {
  return vaddq_s16(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
  return vsubq_s16(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

template <>
inline int16x8_t Neg(int16x8_t a) {
  return vnegq_s16(a);
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(offset));
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
  return vshlq_s32(a, offset);
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
  return vshlq_s16(a, offset);
}

template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

template <>
inline int16x8_t ShiftRight(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(-offset));
}
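
// The mask/select specializations below follow the same convention as the
// generic versions in fixedpoint.h: each comparison produces a per-lane mask
// of all-ones (true) or all-zeros (false), which SelectUsingMask then applies
// as a bitwise select (vbsl). Illustrative sketch only; ClampExample is a
// hypothetical helper, not part of the gemmlowp API:
//
//   int32x4_t ClampExample(int32x4_t x, int32x4_t lo, int32x4_t hi) {
//     x = SelectUsingMask(MaskIfLessThan(x, lo), lo, x);
//     x = SelectUsingMask(MaskIfGreaterThan(x, hi), hi, x);
//     return x;
//   }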
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
                                 int16x8_t else_val) {
  return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
}

template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vceqq_s16(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

template <>
inline int16x8_t MaskIfZero(int16x8_t a) {
  return MaskIfEqual(a, vdupq_n_s16(0));
}

template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int16x8_t MaskIfNonZero(int16x8_t a) {
  return vreinterpretq_s16_u16(vtstq_s16(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgtq_s16(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgeq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcltq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcleq_s16(a, b));
}

template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool All(int16x8_t a) {
  a = vandq_s16(a, vextq_s16(a, a, 1));
  a = vandq_s16(a, vextq_s16(a, a, 2));
  a = vandq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int16x8_t a) {
  a = vorrq_s16(a, vextq_s16(a, a, 1));
  a = vorrq_s16(a, vextq_s16(a, a, 2));
  a = vorrq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

template <>
inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
  return vrhaddq_s16(a, b);
}
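
// The two specializations below map the scalar reference
// SaturatingRoundingDoublingHighMul from fixedpoint.h directly onto VQRDMULH:
// per lane, return the high half of 2*a*b with rounding, saturating the single
// overflow case (both operands equal to the most negative value). Illustrative
// numeric sketch, treating each int32 lane as a Q0.31 value:
//
//   int32x4_t half = vdupq_n_s32(1 << 30);  // 0.5 in Q0.31
//   int32x4_t quarter = SaturatingRoundingDoublingHighMul(half, half);
//   // quarter == vdupq_n_s32(1 << 29), i.e. 0.25 in Q0.31 in every lane.
//
// The RoundingDivideByPOT specializations further below add a sign-dependent
// fixup before the rounding shift (vrshl) so that ties on negative inputs
// round away from zero, matching the scalar reference implementation.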
template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

template <>
inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
  return vqrdmulhq_s16(a, b);
}

template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
  const int16x8_t shift_vec = vdupq_n_s16(-exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
  const int32x4_t shift_vec = vnegq_s32(exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
  const int16x8_t shift_vec = vnegq_s16(exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) {
    const int32x4_t fixup = vshrq_n_s32(x, 31);
    const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
    return vrshrq_n_s32(fixed_up_x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
  static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
  static int16x8_t eval(int16x8_t x) {
    const int16x8_t fixup = vshrq_n_s16(x, 15);
    const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
    return vrshrq_n_s16(fixed_up_x, -Exponent);
  }
};

template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
  return vdupq_n_s32(x);
}

template <>
inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
  return vdupq_n_s16(x);
}

// So far this is only needed for int16.
template <>
inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
  return vqaddq_s16(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
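
// Usage sketch (illustrative only; FixedPoint<>, FromRaw, raw and
// exp_on_negative_values are names from fixedpoint.h as of this writing, not
// defined here): with the specializations above, the generic fixed-point
// wrapper operates on whole NEON registers, four int32 lanes at a time:
//
//   using F3 = gemmlowp::FixedPoint<int32x4_t, 3>;   // Q3.28, 4 lanes
//   using F0 = gemmlowp::FixedPoint<int32x4_t, 0>;   // Q0.31, 4 lanes
//   F3 x = F3::FromRaw(vdupq_n_s32(-(1 << 28)));     // -1.0 in every lane
//   F0 y = gemmlowp::exp_on_negative_values(x);      // ~exp(-1) per lane
//   int32x4_t raw = y.raw();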