// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_t> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};

template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
  return vandq_s16(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
  return vorrq_s16(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
  return veorq_s16(a, b);
}

template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int16x8_t BitNot(int16x8_t a) {
  return veorq_s16(a, vdupq_n_s16(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int16x8_t Add(int16x8_t a, int16x8_t b) {
  return vaddq_s16(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
  return vsubq_s16(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

template <>
inline int16x8_t Neg(int16x8_t a) {
  return vnegq_s16(a);
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(offset));
}

template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

template <>
inline int16x8_t ShiftRight(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(-offset));
}
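// The selection and comparison helpers below follow the lane-mask convention
// of fixedpoint.h: a "true" lane is encoded as all bits set (-1) and a
// "false" lane as 0, which is what the NEON vceqq/vcgtq/... intrinsics
// produce (reinterpreted back to a signed vector). SelectUsingMask then uses
// vbslq to pick, bit by bit, then_val where the mask is set and else_val
// elsewhere.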
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
                                 int16x8_t else_val) {
  return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
}

template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vceqq_s16(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

template <>
inline int16x8_t MaskIfZero(int16x8_t a) {
  return MaskIfEqual(a, vdupq_n_s16(0));
}

template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int16x8_t MaskIfNonZero(int16x8_t a) {
  return vreinterpretq_s16_u16(vtstq_s16(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgtq_s16(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgeq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcltq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcleq_s16(a, b));
}

template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool All(int16x8_t a) {
  a = vandq_s16(a, vextq_s16(a, a, 1));
  a = vandq_s16(a, vextq_s16(a, a, 2));
  a = vandq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int16x8_t a) {
  a = vorrq_s16(a, vextq_s16(a, a, 1));
  a = vorrq_s16(a, vextq_s16(a, a, 2));
  a = vorrq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

template <>
inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
  return vrhaddq_s16(a, b);
}

template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

template <>
inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
  return vqrdmulhq_s16(a, b);
}
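// Note on the RoundingDivideByPOT specializations below: vrshlq rounds
// halfway cases upwards (towards +infinity), whereas the scalar reference in
// fixedpoint.h rounds them away from zero. For a nonzero exponent, the
// "fixup" term evaluates to -1 on negative lanes and 0 otherwise, so
// saturating-adding it nudges negative inputs down by one before the
// rounding shift, making the two behaviours agree.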
template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
  const int16x8_t shift_vec = vdupq_n_s16(-exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) {
    const int32x4_t fixup = vshrq_n_s32(x, 31);
    const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
    return vrshrq_n_s32(fixed_up_x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
  static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
  static int16x8_t eval(int16x8_t x) {
    const int16x8_t fixup = vshrq_n_s16(x, 15);
    const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
    return vrshrq_n_s16(fixed_up_x, -Exponent);
  }
};

template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
  return vdupq_n_s32(x);
}

template <>
inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
  return vdupq_n_s16(x);
}

// So far this is only needed for int16.
template <>
inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
  return vqaddq_s16(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
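// Illustrative usage (a sketch, not part of the header proper). With this
// header pulled in through fixedpoint.h, the two specializations most often
// composed are SaturatingRoundingDoublingHighMul and RoundingDivideByPOT,
// applying a fixed-point multiplier to a vector of int32 accumulators. The
// helper name below is hypothetical:
//
//   int32x4_t ApplyFixedPointMultiplier(int32x4_t acc,
//                                       std::int32_t multiplier,
//                                       int right_shift) {
//     // Rounding, doubling, saturating high multiply (vqrdmulhq_s32).
//     int32x4_t mul = gemmlowp::SaturatingRoundingDoublingHighMul(
//         acc, vdupq_n_s32(multiplier));
//     // Rounding division by 2^right_shift, rounding half away from zero.
//     return gemmlowp::RoundingDivideByPOT(mul, right_shift);
//   }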