1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // fixedpoint.h: fixed-point arithmetic, with basic operations and 16 // a few math functions such as tanh. 17 18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 20 21 #include <cassert> 22 #include <limits> 23 24 #include "../internal/common.h" 25 26 namespace gemmlowp { 27 28 // Part 1: Low-level integer-arithmetic primitives. 29 // The implementations here are generic implementations valid for 30 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types 31 // (e.g. NEON int32x4_t) may be supported by providing 32 // specializations for them in separate files. 33 // 34 // The purpose of these primitives is two-fold: 35 // - They will be used to implement higher-level fixed-point 36 // abstractions, namely the FixedPoint class and its arithmetic 37 // operators. 38 // - They will be directly used to implement some more involved 39 // fixed-point computations, e.g. the fixed-point implementation 40 // of math functions such as tanh. 41 42 // Some compile-time traits around raw types to handle SIMD aspects: 43 // number of lanes, underlying scalar type. 
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain 32-bit scalar: a single lane whose scalar type is the raw type itself.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static const int kLanes = 1;
};

// Plain 16-bit scalar: a single lane whose scalar type is the raw type itself.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static const int kLanes = 1;
};

// Broadcasts the scalar x to every lane of tRawType. This generic version
// covers scalar raw types, for which the value is simply returned.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  return x;
}

// Bit-wise AND of a and b.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return a & b;
}

// Bit-wise OR of a and b.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return a | b;
}

// Bit-wise XOR of a and b.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return a ^ b;
}

// Bit-wise complement of a.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return ~a;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return a + b;
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a
// power of two. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  return a << offset;
}

// Integer arithmetic right-shift by `offset` bits. Not rounding.
// Right-shifting a negative value is implementation-defined in C++; this
// relies on the in-practice-consistent compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  return a >> offset;
}

// Bit-wise select: every result bit is copied from then_val where the
// matching bit of if_mask is set, and from else_val where it is clear.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType bits_from_then = BitAnd(if_mask, then_val);
  const tIntegerType bits_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two operands have no set bit in common, so XOR merges them.
  return BitXor(bits_from_then, bits_from_else);
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType all_clear = 0;
  if (a) {
    return BitNot(all_clear);
  }
  return all_clear;
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const bool is_zero = !a;
  return MaskIfNonZero<tIntegerType>(is_zero);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool is_equal = (a == b);
  return MaskIfNonZero<tIntegerType>(is_equal);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
160 template <typename tIntegerType> 161 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { 162 return MaskIfNonZero<tIntegerType>(a != b); 163 } 164 165 // For each pair of input scalars, the corresponding bits of the result are 166 // set if the input scalars a, b satisfy a > b. 167 template <typename tIntegerType> 168 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { 169 return MaskIfNonZero<tIntegerType>(a > b); 170 } 171 172 // For each pair of input scalars, the corresponding bits of the result are 173 // set if the input scalars a, b satisfy a >= b. 174 template <typename tIntegerType> 175 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { 176 return MaskIfNonZero<tIntegerType>(a >= b); 177 } 178 179 // For each pair of input scalars, the corresponding bits of the result are 180 // set if the input scalars a, b satisfy a < b. 181 template <typename tIntegerType> 182 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 183 return MaskIfNonZero<tIntegerType>(a < b); 184 } 185 186 // For each pair of input scalars, the corresponding bits of the result are 187 // set if the input scalars a, b satisfy a <= b. 188 template <typename tIntegerType> 189 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 190 return MaskIfNonZero<tIntegerType>(a <= b); 191 } 192 193 // Returns true if all of the input scalars are nonzero. 194 // This function may currently assume that each of the input scalars has either 195 // all or none of its bits set. Otherwise, its behavior is currently undefined. 196 template <typename tIntegerType> 197 bool All(tIntegerType a) { 198 return a; 199 } 200 201 // Returns true if any of the input scalars are nonzero. 202 // This function may currently assume that each of the input scalars has either 203 // all or none of its bits set. Otherwise, its behavior is currently undefined. 
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return a;
}

// Returns (a+b)/2, rounded to the nearest integer.
// Equivalent to VRHADD in the ARM NEON instruction set.
//
// The primary template is a compile-time trap: only the explicit
// specializations below are implemented.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// 32-bit version: the sum is formed in 64 bits so it cannot overflow, then
// biased by +1 (or -1 when negative) ahead of the truncating division by 2.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  const std::int64_t bias = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((wide_sum + bias) / 2);
}

// 16-bit version: same scheme, with the sum formed in 32 bits.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t bias = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int16_t>((wide_sum + bias) / 2);
}

// Returns a+b clamped to the representable range of IntegerType.
// The primary template is a compile-time trap: only the explicit
// specialization below is implemented.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// So far this is only needed for int16.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  // Add in 32 bits, then clamp to [-32768, 32767].
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  if (wide_sum > 32767) {
    return 32767;
  }
  if (wide_sum < -32768) {
    return -32768;
  }
  return static_cast<std::int16_t>(wide_sum);
}

// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
252 template <typename IntegerType, bool Is16Bit> 253 struct AddSaturatingIf16BitImpl { 254 static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } 255 }; 256 template <typename IntegerType> 257 struct AddSaturatingIf16BitImpl<IntegerType, true> { 258 static IntegerType Run(IntegerType a, IntegerType b) { 259 return SaturatingAdd(a, b); 260 } 261 }; 262 template <typename IntegerType> 263 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { 264 using ScalarType = 265 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 266 return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, 267 b); 268 } 269 270 // Returns the integer that represents the product of two fixed-point 271 // numbers, interpreting all integers as fixed-point values in the 272 // interval [-1, 1), rounding to the nearest value, and saturating 273 // -1 * -1 to the maximum value (since 1 is not in the half-open 274 // interval [-1, 1)). 275 // 276 // [The explanation below specializes to std::int32_t for example purpose.] 277 // 278 // The mapping between IntegerType and the interval [-1, 1) is unique and 279 // implied by IntegerType, which is assumed to be signed. For example, 280 // for IntegerType==std::int32_t, the mapping is 281 // real_value = integer_value / 2^31. 282 // So in this case, and leaving aside rounding and saturating, this 283 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to 284 // (a * b) / 2^31. 285 // 286 // The 'doubling' part in the name of this function comes from the fact that 287 // this operation is very close to a "multiply-high" operation, keeping only 288 // the top half bits, except that that would be effectively computing 289 // (a * b) / 2^32, 290 // so here we are computing 2x that, since 291 // 1/2^31 = 2 * 1/2^32. 292 // The idea is to use all of the available 32 bits in the destination int32 293 // value. 294 // 295 // [End of the explanation specializing to int32.] 
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
//
// The primary template is a compile-time trap: only the explicit
// specializations below are implemented.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// This function implements the same computation as the ARMv7 NEON VQRDMULH
// instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  typedef std::numeric_limits<std::int32_t> Limits;
  // min * min is the only input pair whose result does not fit: saturate.
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Rounding nudge added ahead of the truncating division by 2^31.
  const std::int32_t nudge = (product < 0) ? (1 - (1 << 30)) : (1 << 30);
  return static_cast<std::int32_t>((product + nudge) / (1ll << 31));
}

// 16-bit counterpart: the product is formed in 32 bits and divided by 2^15.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  typedef std::numeric_limits<std::int16_t> Limits;
  // min * min is the only input pair whose result does not fit: saturate.
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  // Rounding nudge added ahead of the truncating division by 2^15.
  const std::int32_t nudge = (product < 0) ? (1 - (1 << 14)) : (1 << 14);
  return static_cast<std::int16_t>((product + nudge) / (1 << 15));
}

// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
334 template <typename IntegerType> 335 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { 336 assert(exponent >= 0); 337 assert(exponent <= 31); 338 const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 339 const IntegerType zero = Dup<IntegerType>(0); 340 const IntegerType one = Dup<IntegerType>(1); 341 const IntegerType remainder = BitAnd(x, mask); 342 const IntegerType threshold = 343 Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 344 return Add(ShiftRight(x, exponent), 345 BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 346 } 347 348 // Returns the product of a run-time integer value by a compile-time power 349 // of two, with either a positive exponent (equivalent to an arithmetic 350 // left shift, saturating) or a negative exponent (equivalent to an arithmetic 351 // right shift, rounding to nearest). 352 template <int Exponent, typename IntegerType, 353 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> 354 struct ImplSaturatingRoundingMultiplyByPOT {}; 355 356 template <int Exponent, typename IntegerType> 357 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 358 static IntegerType eval(IntegerType x) { return x; } 359 }; 360 361 template <int Exponent, typename IntegerType> 362 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 363 static IntegerType eval(IntegerType x) { 364 using ScalarIntegerType = 365 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 366 const IntegerType min = 367 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); 368 const IntegerType max = 369 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); 370 const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); 371 372 const std::int32_t threshold = 373 ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); 374 const IntegerType positive_mask = 375 MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 376 const IntegerType negative_mask = 
377 MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 378 379 IntegerType result = ShiftLeft(x, Exponent); 380 result = SelectUsingMask(positive_mask, max, result); 381 result = SelectUsingMask(negative_mask, min, result); 382 return result; 383 } 384 }; 385 386 template <int Exponent, typename IntegerType> 387 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 388 static IntegerType eval(IntegerType x) { 389 return RoundingDivideByPOT<IntegerType>(x, -Exponent); 390 } 391 }; 392 393 template <int Exponent, typename IntegerType> 394 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 395 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 396 } 397 398 // Part 2: the FixedPoint class. 399 400 // A FixedPoint object represents a fixed-point value stored in the underlying 401 // integer type tRawType, if tRawType is a plain scalar integer type. 402 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which 403 // case a FixedPoint object represents a corresponding SIMD vector of fixed 404 // point values. 405 // 406 // tIntegerBits describes the range of the fixed-point format: if 407 // tIntegerBits == m then the range of representable values is the half-open 408 // interval [-2^m; 2^m) where the open boundary on the right side means that 409 // 2^m is not representable (how close the maximum representable value is to 410 // it, depends on bit-depth of tRawType). 411 // 412 // In "Q format notation", 413 // https://en.wikipedia.org/wiki/Q_(number_format) 414 // we are describing the format 415 // Qm.n 416 // where 417 // m = tIntegerBits 418 // and 419 // n = NumberOfBits(tRawType) - (m + 1) 420 // Note that the (m + 1) in the above line is because we adopt the convention 421 // that we count the integer bits exclusively of the sign bit; so (m + 1) is 422 // the total number of integer bits inclusive of the sign bit. 
423 // 424 // Accordingly, the number of integral representable values in our range 425 // [-2^m ; 2^m) 426 // is equal to 2^(m+1). 427 template <typename tRawType, int tIntegerBits> 428 class FixedPoint { 429 public: 430 typedef tRawType RawType; 431 432 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 433 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 434 435 static const int kTotalBits = 8 * sizeof(ScalarRawType); 436 static const int kIntegerBits = tIntegerBits; 437 static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; 438 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 439 "bad IntegerBits"); 440 441 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 442 443 static const ScalarRawType ScalarRawMin() { 444 return std::numeric_limits<ScalarRawType>::min(); 445 } 446 447 static const ScalarRawType ScalarRawMax() { 448 return std::numeric_limits<ScalarRawType>::max(); 449 } 450 451 static const ScalarRawType RawMin() { 452 return VectorFromScalar(ScalarRawMin()); 453 } 454 455 static const ScalarRawType RawMax() { 456 return VectorFromScalar(ScalarRawMax()); 457 } 458 459 static FixedPoint FromRaw(RawType x) { 460 FixedPoint retval; 461 retval.raw() = x; 462 return retval; 463 } 464 465 static FixedPoint FromScalarRaw(ScalarRawType x) { 466 FixedPoint retval; 467 retval.raw() = Dup<RawType>(x); 468 return retval; 469 } 470 471 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 472 return FromScalarRaw(x.raw()); 473 } 474 475 template <int Exponent> 476 static FixedPoint ConstantPOT() { 477 static const int kOffset = kFractionalBits + Exponent; 478 static_assert( 479 kOffset < 31, 480 "Constant not exactly representable in this fixed-point format"); 481 return FromScalarRaw(ScalarRawType(1) << kOffset); 482 } 483 484 static FixedPoint Zero() { return FromScalarRaw(0); } 485 486 static FixedPoint One() { 487 return FromScalarRaw( 488 kIntegerBits == 0 489 ? 
ScalarRawMax() 490 : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); 491 } 492 493 static FixedPoint FromDouble(double x) { 494 const double min_bound = static_cast<double>(ScalarRawMin()); 495 const double max_bound = static_cast<double>(ScalarRawMax()); 496 return FromScalarRaw(static_cast<ScalarRawType>(std::min( 497 std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 498 min_bound), 499 max_bound))); 500 } 501 502 RawType raw() const { return i_; } 503 RawType& raw() { return i_; } 504 505 private: 506 RawType i_; 507 }; 508 509 // Part 3: implementation of arithmetic operators for the 510 // FixedPoint class, and a few related functions. 511 512 // A FixedPoint multiplication is just a 513 // SaturatingRoundingDoublingHighMul operation on the underlying 514 // raw integer values. The IntegerBits simply add up, as is obvious 515 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 516 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 517 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 518 FixedPoint<tRawType, tIntegerBits_a> a, 519 FixedPoint<tRawType, tIntegerBits_b> b) { 520 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 521 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 522 return c; 523 } 524 525 // Tweaking IntegerBits gives exact multiplication by a power of two. 526 template <int tExponent, typename tRawType, int tIntegerBits> 527 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 528 FixedPoint<tRawType, tIntegerBits> a) { 529 FixedPoint<tRawType, tExponent + tIntegerBits> c; 530 c.raw() = a.raw(); 531 return c; 532 } 533 534 // If we want to leave IntegerBits fixed, then multiplication 535 // by a power of two has to be saturating/rounding, not exact anymore. 
536 template <int tExponent, typename tRawType, int tIntegerBits> 537 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 538 FixedPoint<tRawType, tIntegerBits> a) { 539 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 540 SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 541 } 542 543 // Generic arithmetic operators. 544 545 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 546 template <typename tRawType, int tIntegerBits> \ 547 FixedPoint<tRawType, tIntegerBits> FuncName( \ 548 FixedPoint<tRawType, tIntegerBits> a) { \ 549 return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 550 } 551 552 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 553 template <typename tRawType, int tIntegerBits> \ 554 FixedPoint<tRawType, tIntegerBits> FuncName( \ 555 FixedPoint<tRawType, tIntegerBits> a, \ 556 FixedPoint<tRawType, tIntegerBits> b) { \ 557 return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 558 ImplFuncName(a.raw(), b.raw())); \ 559 } 560 561 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 562 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 563 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 564 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 565 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 566 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 567 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 568 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 569 570 #undef MAKE_FIXEDPOINT_UNARY_FUNC 571 #undef MAKE_FIXEDPOINT_BINARY_FUNC 572 573 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 574 template <typename tRawType, int tIntegerBits> \ 575 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 576 return FuncName(a.raw()); \ 577 } 578 579 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 580 template <typename tRawType, int tIntegerBits> \ 581 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ 582 FixedPoint<tRawType, tIntegerBits> b) { \ 583 return FuncName(a.raw(), 
b.raw()); \ 584 } 585 586 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 587 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 588 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 589 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 590 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 591 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 592 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 593 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 594 595 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 596 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 597 598 template <typename tRawType, int tIntegerBits> 599 FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 600 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 601 FixedPoint<tRawType, tIntegerBits> else_val) { 602 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 603 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 604 } 605 606 template <typename tRawType, int tIntegerBits> 607 bool operator==(FixedPoint<tRawType, tIntegerBits> a, 608 FixedPoint<tRawType, tIntegerBits> b) { 609 return All(MaskIfEqual(a.raw(), b.raw())); 610 } 611 612 template <typename tRawType, int tIntegerBits> 613 bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 614 FixedPoint<tRawType, tIntegerBits> b) { 615 return !(a == b); 616 } 617 618 template <typename tRawType, int tIntegerBits> 619 FixedPoint<tRawType, tIntegerBits> SaturatingAdd( 620 FixedPoint<tRawType, tIntegerBits> a, 621 FixedPoint<tRawType, tIntegerBits> b) { 622 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 623 SaturatingAdd(a.raw(), b.raw())); 624 } 625 626 template <typename tRawType, int tIntegerBits> 627 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit( 628 FixedPoint<tRawType, tIntegerBits> a, 629 FixedPoint<tRawType, tIntegerBits> b) { 630 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 631 AddSaturatingIf16Bit(a.raw(), b.raw())); 632 } 633 634 
// Conversion to floating-point. 635 template <typename tRawType, int tIntegerBits> 636 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 637 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 638 "not applicable to SIMD types"); 639 typedef FixedPoint<tRawType, tIntegerBits> F; 640 return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 641 } 642 643 // Rescale changes the number of IntegerBits and updates the underlying 644 // raw integer value accordingly. 645 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 646 FixedPoint<tRawType, tIntegerBitsDst> Rescale( 647 FixedPoint<tRawType, tIntegerBitsSrc> x) { 648 static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 649 FixedPoint<tRawType, tIntegerBitsDst> result; 650 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 651 return result; 652 } 653 654 // CheckedFixedPointConstant allows to specify fixed-point constants 655 // initialized as real numbers, in a way that does not compile floating-point 656 // arithmetic in production code, yet still checks agreement with the 657 // floating-point expressions when asserts are enabled. 658 // 659 // The raw integer value provided is always a int32, encoding a 32-bit 660 // fixed-point value, regardless of the actual Scalar type. This allows 661 // writing generic code that applies just as well to the 32-bit and 16-bit 662 // cases. In the 16-bit case, the raw integer value is internally 663 // rounding-shifted by 16 bits to the right. 
664 template <typename FixedPointType> 665 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( 666 std::int32_t int32_value) { 667 typedef typename FixedPointType::ScalarRawType ScalarRawType; 668 static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); 669 return static_cast<ScalarRawType>( 670 RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); 671 } 672 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 673 template <typename FixedPointType> 674 FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, 675 double double_value) { 676 const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 677 assert(result == FixedPointType::FromDouble(double_value)); 678 return result; 679 } 680 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 681 ScalarRawInt32Value, DoubleValue) \ 682 (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \ 683 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 684 ScalarRawInt32Value), \ 685 DoubleValue)) 686 687 #else 688 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 689 ScalarRawInt32Value, DoubleValue) \ 690 (FixedPointType::FromScalarRaw( \ 691 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 692 ScalarRawInt32Value))) 693 #endif 694 695 // Implementation of exponential function. 696 697 // Returns exp(x) for x in [-1/4, 0). 698 template <typename tRawType> 699 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 700 FixedPoint<tRawType, 0> a) { 701 typedef FixedPoint<tRawType, 0> F; 702 const F constant_term = 703 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 704 const F constant_1_over_3 = 705 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 706 // We're evaluating a Taylor expansion around -1/8, so we do the change of 707 // variable: x = a + 1/8. 708 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 
709 F x = a + F::template ConstantPOT<-3>(); 710 F x2 = x * x; 711 F x3 = x2 * x; 712 F x4 = x2 * x2; 713 F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 714 F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 715 SaturatingRoundingMultiplyByPOT<-1>( 716 ((x4_over_4 + x3) * constant_1_over_3) + x2); 717 return AddSaturatingIf16Bit( 718 constant_term, 719 constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); 720 } 721 722 // Returns exp(x) for x < 0. 723 template <typename tRawType, int tIntegerBits> 724 FixedPoint<tRawType, 0> exp_on_negative_values( 725 FixedPoint<tRawType, tIntegerBits> a) { 726 typedef FixedPoint<tRawType, tIntegerBits> InputF; 727 typedef FixedPoint<tRawType, 0> ResultF; 728 static const int kFractionalBits = InputF::kFractionalBits; 729 static const int kIntegerBits = InputF::kIntegerBits; 730 static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 731 InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 732 InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 733 ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 734 Rescale<0>(a_mod_quarter_minus_one_quarter)); 735 tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 736 737 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 738 if (kIntegerBits > Exponent) { \ 739 const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 740 ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 741 static constexpr int kShiftAmount = \ 742 kIntegerBits > Exponent ? 
kFractionalBits + Exponent : 0; \ 743 result = SelectUsingMask( \ 744 MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ 745 result * kMultiplier, result); \ 746 } 747 748 GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); 749 GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); 750 GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); 751 GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); 752 GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); 753 GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); 754 GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); 755 756 #undef GEMMLOWP_EXP_BARREL_SHIFTER 757 758 if (kIntegerBits > 5) { 759 static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0; 760 const InputF clamp = 761 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); 762 result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 763 } 764 765 result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 766 return result; 767 } 768 769 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). 770 771 // Returns (1 - x) / (1 + x) for x in (0, 1). 772 template <typename tRawType> 773 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 774 FixedPoint<tRawType, 0> a) { 775 typedef FixedPoint<tRawType, 0> F0; 776 typedef FixedPoint<tRawType, 2> F2; 777 F0 half_denominator = RoundingHalfSum(a, F0::One()); 778 // Newton-Raphson division 779 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 780 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 
781 const F2 constant_48_over_17 = 782 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 783 const F2 constant_neg_32_over_17 = 784 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 785 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 786 for (int i = 0; i < 3; i++) { 787 F2 half_denominator_times_x = half_denominator * x; 788 F2 one_minus_half_denominator_times_x = 789 F2::One() - half_denominator_times_x; 790 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 791 } 792 return Rescale<0>(x - F2::One()); 793 } 794 795 // Returns -tanh(x) for x < 0. 796 template <typename tRawType, int tIntegerBits> 797 FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 798 FixedPoint<tRawType, tIntegerBits> a) { 799 return one_minus_x_over_one_plus_x_for_x_in_0_1( 800 exp_on_negative_values(ExactMulByPot<1>(a))); 801 } 802 803 // Returns tanh(x) for any x. 804 template <typename tRawType, int tIntegerBits> 805 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 806 typedef FixedPoint<tRawType, tIntegerBits> InputF; 807 typedef FixedPoint<tRawType, 0> ResultF; 808 tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 809 tRawType mask_if_zero = MaskIfZero(a); 810 InputF n = SelectUsingMask(mask_if_negative, a, -a); 811 ResultF t = neg_tanh_on_negative_values(n); 812 return SelectUsingMask(mask_if_zero, ResultF::Zero(), 813 SelectUsingMask(mask_if_negative, -t, t)); 814 } 815 816 // Implementation of logistic function. 817 818 // Returns 1 / (1 + x) for x in (0, 1). 
819 template <typename tRawType> 820 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 821 FixedPoint<tRawType, 0> a) { 822 typedef FixedPoint<tRawType, 0> F0; 823 typedef FixedPoint<tRawType, 2> F2; 824 F0 half_denominator = RoundingHalfSum(a, F0::One()); 825 // Newton-Raphson division 826 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 827 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 828 const F2 constant_48_over_17 = 829 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 830 const F2 constant_neg_32_over_17 = 831 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 832 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 833 for (int i = 0; i < 3; i++) { 834 F2 half_denominator_times_x = half_denominator * x; 835 F2 one_minus_half_denominator_times_x = 836 F2::One() - half_denominator_times_x; 837 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 838 } 839 return Rescale<0>(ExactMulByPot<-1>(x)); 840 } 841 842 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 843 template <typename tRawType, int tIntegerBits> 844 FixedPoint<tRawType, 0> logistic_on_positive_values( 845 FixedPoint<tRawType, tIntegerBits> a) { 846 return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 847 } 848 849 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 
850 template <typename tRawType, int tIntegerBits> 851 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 852 typedef FixedPoint<tRawType, tIntegerBits> InputF; 853 typedef FixedPoint<tRawType, 0> ResultF; 854 tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 855 tRawType mask_if_zero = MaskIfZero(a); 856 InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 857 ResultF result_if_positive = logistic_on_positive_values(abs_input); 858 ResultF result_if_negative = ResultF::One() - result_if_positive; 859 const ResultF one_half = 860 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 861 return SelectUsingMask(mask_if_zero, one_half, 862 SelectUsingMask(mask_if_positive, result_if_positive, 863 result_if_negative)); 864 } 865 866 } // end namespace gemmlowp 867 868 #ifdef GEMMLOWP_NEON 869 #include "./fixedpoint_neon.h" 870 #elif defined(GEMMLOWP_SSE4) 871 #include "./fixedpoint_sse.h" 872 #elif defined(GEMMLOWP_MSA) 873 #include "./fixedpoint_msa.h" 874 #endif 875 876 #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 877