1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
16 // a few math functions such as tanh.
17 
18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
20 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
26 
27 #include "../internal/detect_platform.h"
28 
29 namespace gemmlowp {
30 
31 // Part 1: Low-level integer-arithmetic primitives.
32 // The implementations here are generic implementations valid for
33 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
34 // (e.g. NEON int32x4_t) may be supported by providing
35 // specializations for them in separate files.
36 //
37 // The purpose of these primitives is two-fold:
38 //  - They will be used to implement higher-level fixed-point
39 //    abstractions, namely the FixedPoint class and its arithmetic
40 //    operators.
41 //  - They will be directly used to implement some more involved
42 //    fixed-point computations, e.g. the fixed-point implementation
43 //    of math functions such as tanh.
44 
// Compile-time description of a raw type, used to abstract over SIMD:
//   - ScalarRawType: the scalar integer type held in each lane;
//   - kLanes: how many such scalars the raw type holds.
// The primary template is intentionally empty; every supported raw type
// provides its own specialization.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain 32-bit scalar: a single lane of std::int32_t.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  static constexpr int kLanes = 1;
  using ScalarRawType = std::int32_t;
};

// Plain 16-bit scalar: a single lane of std::int16_t.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  static constexpr int kLanes = 1;
  using ScalarRawType = std::int16_t;
};
61 
62 // Returns a SIMD value duplicating a scalar value across all lanes.
63 template <typename tRawType>
64 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
65   return x;
66 }
67 
// Bit-wise AND of a and b.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType common_bits = a & b;
  return common_bits;
}
73 
// Bit-wise OR of a and b.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType any_bits = a | b;
  return any_bits;
}
79 
// Bit-wise exclusive OR of a and b.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType differing_bits = a ^ b;
  return differing_bits;
}
85 
// Bit-wise complement of a.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType flipped = ~a;
  return flipped;
}
91 
// Returns a + b.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}
97 
// Returns a * b.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  const tIntegerType product = a * b;
  return product;
}
103 
// Returns a - b.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}
109 
// Returns -a.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
115 
// Integer arithmetic left-shift, equivalent to multiplying by 2^offset.
// Negative values of `a` are OK. The multiplication is carried out in 64 bits,
// so there is no undefined behavior when the result does not fit in
// tIntegerType; in that case the result is currently clamped to the
// min/max of tIntegerType, but callers should not rely on that and are
// expected to handle overflowing cases themselves (compare-and-mask).
//
// tIntegerType may be int32 or any narrower signed type, and offset is
// expected to be in [0, 31] so that the widened product fits in 64 bits.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  // Build the power of two in 64 bits. A plain `1 << offset` is evaluated as a
  // 32-bit int, which cannot represent 2^31 and would flip the sign of the
  // product for offset == 31.
  const std::int64_t wide_pot = static_cast<std::int64_t>(1) << offset;
  const std::int64_t wide_shifted = wide_a * wide_pot;
  const auto min = std::numeric_limits<tIntegerType>::min();
  const auto max = std::numeric_limits<tIntegerType>::max();
  if (wide_shifted < min) {
    return min;
  }
  if (wide_shifted > max) {
    return max;
  }
  return static_cast<tIntegerType>(wide_shifted);
}
136 
// Integer arithmetic right-shift by `offset` bits, without any rounding.
// For negative `a` this relies on the compiler's implementation-defined
// (but in practice consistent) behavior of `>>` on signed integers.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
144 
// Bit-wise select: wherever a bit of if_mask is set, the result takes the
// corresponding bit of then_val; wherever it is clear, the result takes the
// corresponding bit of else_val.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType bits_from_then = BitAnd(if_mask, then_val);
  const tIntegerType bits_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two contributions never overlap, so XOR merges them.
  return BitXor(bits_from_then, bits_from_else);
}
153 
// Returns a mask with all bits set if the input scalar is non-zero, and
// with no bits set otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static constexpr tIntegerType kNoBitsSet = 0;
  if (a) {
    return BitNot(kNoBitsSet);
  }
  return kNoBitsSet;
}
161 
162 // For each input scalar, the corresponding bits of the result are set if the
163 // input scalar is zero.
164 template <typename tIntegerType>
165 tIntegerType MaskIfZero(tIntegerType a) {
166   return MaskIfNonZero<tIntegerType>(!a);
167 }
168 
169 // For each pair of input scalars, the corresponding bits of the result are
170 // set if the input scalars are equal.
171 template <typename tIntegerType>
172 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
173   return MaskIfNonZero<tIntegerType>(a == b);
174 }
175 
176 // For each pair of input scalars, the corresponding bits of the result are
177 // set if the input scalars are not equal.
178 template <typename tIntegerType>
179 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
180   return MaskIfNonZero<tIntegerType>(a != b);
181 }
182 
183 // For each pair of input scalars, the corresponding bits of the result are
184 // set if the input scalars a, b satisfy a > b.
185 template <typename tIntegerType>
186 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
187   return MaskIfNonZero<tIntegerType>(a > b);
188 }
189 
190 // For each pair of input scalars, the corresponding bits of the result are
191 // set if the input scalars a, b satisfy a >= b.
192 template <typename tIntegerType>
193 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
194   return MaskIfNonZero<tIntegerType>(a >= b);
195 }
196 
197 // For each pair of input scalars, the corresponding bits of the result are
198 // set if the input scalars a, b satisfy a < b.
199 template <typename tIntegerType>
200 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
201   return MaskIfNonZero<tIntegerType>(a < b);
202 }
203 
204 // For each pair of input scalars, the corresponding bits of the result are
205 // set if the input scalars a, b satisfy a <= b.
206 template <typename tIntegerType>
207 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
208   return MaskIfNonZero<tIntegerType>(a <= b);
209 }
210 
// Returns true if all of the input scalars are nonzero.
// The input is expected to be a mask, i.e. each scalar has either all or
// none of its bits set; behavior for other inputs is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
  return static_cast<bool>(a);
}
218 
// Returns true if any of the input scalars are nonzero.
// The input is expected to be a mask, i.e. each scalar has either all or
// none of its bits set; behavior for other inputs is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return static_cast<bool>(a);
}
226 
// Returns (a+b)/2, rounded to the nearest integer.
// Scalar counterpart of VRHADD in the ARM NEON instruction set.
// NOTE(review): the scalar specializations below round exact halves away
// from zero; confirm this matches the SIMD paths for negative odd sums.
//
// The generic template is a compile-time trap: only the explicit
// specializations below are implemented.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// 32-bit specialization. The sum is formed in 64 bits so it cannot overflow.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Nudging by one unit in the direction of the sum's sign makes the
  // truncating division round to nearest.
  const std::int64_t nudge = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((wide_sum + nudge) / 2);
}

// 16-bit specialization. The sum is formed in 32 bits so it cannot overflow.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t nudge = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int16_t>((wide_sum + nudge) / 2);
}
253 
// Returns a+b, clamped to the representable range of IntegerType instead of
// overflowing.
//
// The generic template is a compile-time trap: only the explicit
// specializations below (int16 and int8) are implemented.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// 16-bit specialization. The sum is formed in 32 bits, where it cannot
// overflow, and is then clamped to the int16 range.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t lowest = std::numeric_limits<std::int16_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int16_t>::max();
  return static_cast<std::int16_t>(std::min(highest, std::max(lowest, sum)));
}

// 8-bit specialization. The sum is formed in 32 bits, where it cannot
// overflow, and is then clamped to the int8 range.
template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t lowest = std::numeric_limits<std::int8_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int8_t>::max();
  return static_cast<std::int8_t>(std::min(highest, std::max(lowest, sum)));
}
281 
282 // Returns a+b, saturating if the integers are 16bit or narrower,
283 // otherwise just a plain addition.
284 template <typename IntegerType, bool Is16Bit>
285 struct AddSaturatingIf16BitImpl {
286   static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
287 };
288 template <typename IntegerType>
289 struct AddSaturatingIf16BitImpl<IntegerType, true> {
290   static IntegerType Run(IntegerType a, IntegerType b) {
291     return SaturatingAdd(a, b);
292   }
293 };
294 template <typename IntegerType>
295 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
296   using ScalarType =
297       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
298   return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
299                                                                              b);
300 }
301 
// Fixed-point multiplication of two values in [-1, 1).
//
// Both inputs and the result are integers interpreted as fixed-point values
// in the half-open interval [-1, 1). The result is rounded to the nearest
// representable value, and the single overflowing case, -1 * -1, saturates
// to the maximum value since 1 itself is not representable.
//
// [Worked example for std::int32_t.]
//
// IntegerType is assumed to be signed and fully determines the mapping to
// [-1, 1). For std::int32_t that mapping is
//   real_value = integer_value / 2^31.
// Ignoring rounding and saturation, the function therefore computes
// ((a / 2^31) * (b / 2^31)) * 2^31, i.e.
//   (a * b) / 2^31.
//
// Why 'doubling': a plain "multiply-high" keeps the top half of the product
// bits, which amounts to
//   (a * b) / 2^32.
// This function returns twice that, because
//   1/2^31 = 2 * 1/2^32,
// so that all 32 bits of the destination are put to use.
//
// [End of the int32 example.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
//
// The generic template is a compile-time trap: only the explicit
// specializations below are implemented.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// 32-bit specialization, implementing the same computation as the ARMv7 NEON
// VQRDMULH instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  // min * min is the only input pair whose result (+1) is not representable.
  if (a == std::numeric_limits<std::int32_t>::min() && b == a) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Offset by (almost) half of the divisor, signed like the product, so that
  // the truncating division below rounds to nearest.
  const std::int32_t nudge = (product >= 0) ? (1 << 30) : (1 - (1 << 30));
  return static_cast<std::int32_t>((product + nudge) / (1ll << 31));
}

// 16-bit specialization; same scheme as the 32-bit one with all widths halved.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  // min * min is the only input pair whose result (+1) is not representable.
  if (a == std::numeric_limits<std::int16_t>::min() && b == a) {
    return std::numeric_limits<std::int16_t>::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int16_t nudge = (product >= 0) ? (1 << 14) : (1 - (1 << 14));
  return static_cast<std::int16_t>((product + nudge) / (1 << 15));
}
364 
365 // Correctly-rounded-to-nearest division by a power-of-two.
366 // Also known as a rounding arithmetic right shift.
367 template <typename IntegerType, typename ExponentType>
368 inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
369   assert(exponent >= 0);
370   assert(exponent <= 31);
371   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
372   const IntegerType zero = Dup<IntegerType>(0);
373   const IntegerType one = Dup<IntegerType>(1);
374   const IntegerType remainder = BitAnd(x, mask);
375   const IntegerType threshold =
376       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
377   return Add(ShiftRight(x, exponent),
378              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
379 }
380 
381 // Returns the product of a run-time integer value by a compile-time power
382 // of two, with either a positive exponent (equivalent to an arithmetic
383 // left shift, saturating) or a negative exponent (equivalent to an arithmetic
384 // right shift, rounding to nearest).
385 template <int Exponent, typename IntegerType,
386           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
387 struct ImplSaturatingRoundingMultiplyByPOT {};
388 
389 template <int Exponent, typename IntegerType>
390 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
391   static IntegerType eval(IntegerType x) { return x; }
392 };
393 
394 template <int Exponent, typename IntegerType>
395 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
396   static IntegerType eval(IntegerType x) {
397     using ScalarIntegerType =
398         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
399     const IntegerType min =
400         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
401     const IntegerType max =
402         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
403     const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
404 
405     const std::int32_t threshold =
406         ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
407     const IntegerType positive_mask =
408         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
409     const IntegerType negative_mask =
410         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
411 
412     IntegerType result = ShiftLeft(x, Exponent);
413     result = SelectUsingMask(positive_mask, max, result);
414     result = SelectUsingMask(negative_mask, min, result);
415     return result;
416   }
417 };
418 
419 template <int Exponent, typename IntegerType>
420 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
421   static IntegerType eval(IntegerType x) {
422     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
423   }
424 };
425 
426 template <int Exponent, typename IntegerType>
427 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
428   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
429 }
430 
431 // Part 2: the FixedPoint class.
432 
433 // A FixedPoint object represents a fixed-point value stored in the underlying
434 // integer type tRawType, if tRawType is a plain scalar integer type.
435 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
436 // case a FixedPoint object represents a corresponding SIMD vector of fixed
437 // point values.
438 //
439 // tIntegerBits describes the range of the fixed-point format: if
440 // tIntegerBits == m then the range of representable values is the half-open
441 // interval [-2^m; 2^m) where the open boundary on the right side means that
442 // 2^m is not representable (how close the maximum representable value is to
443 // it, depends on bit-depth of tRawType).
444 //
445 // In "Q format notation",
446 //   https://en.wikipedia.org/wiki/Q_(number_format)
447 // we are describing the format
448 //   Qm.n
449 // where
450 //   m = tIntegerBits
451 // and
452 //   n = NumberOfBits(tRawType) - (m + 1)
453 // Note that the (m + 1) in the above line is because we adopt the convention
454 // that we count the integer bits exclusively of the sign bit; so (m + 1) is
455 // the total number of integer bits inclusive of the sign bit.
456 //
457 // Accordingly, the number of integral representable values in our range
458 //   [-2^m ; 2^m)
459 // is equal to 2^(m+1).
460 template <typename tRawType, int tIntegerBits>
461 class FixedPoint {
462  public:
463   typedef tRawType RawType;
464 
465   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
466   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
467 
468   static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
469   static constexpr int kIntegerBits = tIntegerBits;
470   static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
471   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
472                 "bad IntegerBits");
473 
474   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
475 
476   static const ScalarRawType ScalarRawMin() {
477     return std::numeric_limits<ScalarRawType>::min();
478   }
479 
480   static const ScalarRawType ScalarRawMax() {
481     return std::numeric_limits<ScalarRawType>::max();
482   }
483 
484   static const ScalarRawType RawMin() {
485     return VectorFromScalar(ScalarRawMin());
486   }
487 
488   static const ScalarRawType RawMax() {
489     return VectorFromScalar(ScalarRawMax());
490   }
491 
492   static FixedPoint FromRaw(RawType x) {
493     FixedPoint retval;
494     retval.raw() = x;
495     return retval;
496   }
497 
498   static FixedPoint FromScalarRaw(ScalarRawType x) {
499     FixedPoint retval;
500     retval.raw() = Dup<RawType>(x);
501     return retval;
502   }
503 
504   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
505     return FromScalarRaw(x.raw());
506   }
507 
508   template <int Exponent>
509   static FixedPoint ConstantPOT() {
510     static constexpr int kOffset = kFractionalBits + Exponent;
511     static_assert(
512         kOffset < 31,
513         "Constant not exactly representable in this fixed-point format");
514     return FromScalarRaw(ScalarRawType(1) << kOffset);
515   }
516 
517   static FixedPoint Zero() { return FromScalarRaw(0); }
518 
519   static FixedPoint One() {
520     return FromScalarRaw(
521         kIntegerBits == 0
522             ? ScalarRawMax()
523             : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
524   }
525 
526   static FixedPoint FromDouble(double x) {
527     const double min_bound = static_cast<double>(ScalarRawMin());
528     const double max_bound = static_cast<double>(ScalarRawMax());
529     return FromScalarRaw(static_cast<ScalarRawType>(std::min(
530         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
531                  min_bound),
532         max_bound)));
533   }
534 
535   RawType raw() const { return i_; }
536   RawType& raw() { return i_; }
537 
538  private:
539   RawType i_;
540 };
541 
542 // Part 3: implementation of arithmetic operators for the
543 // FixedPoint class, and a few related functions.
544 
545 // A FixedPoint multiplication is just a
546 // SaturatingRoundingDoublingHighMul operation on the underlying
547 // raw integer values. The IntegerBits simply add up, as is obvious
548 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
549 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
550 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
551     FixedPoint<tRawType, tIntegerBits_a> a,
552     FixedPoint<tRawType, tIntegerBits_b> b) {
553   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
554   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
555   return c;
556 }
557 
558 // Tweaking IntegerBits gives exact multiplication by a power of two.
559 template <int tExponent, typename tRawType, int tIntegerBits>
560 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
561     FixedPoint<tRawType, tIntegerBits> a) {
562   FixedPoint<tRawType, tExponent + tIntegerBits> c;
563   c.raw() = a.raw();
564   return c;
565 }
566 
567 // If we want to leave IntegerBits fixed, then multiplication
568 // by a power of two has to be saturating/rounding, not exact anymore.
569 template <int tExponent, typename tRawType, int tIntegerBits>
570 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
571     FixedPoint<tRawType, tIntegerBits> a) {
572   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
573       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
574 }
575 
576 // Generic arithmetic operators.
577 
578 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
579   template <typename tRawType, int tIntegerBits>                               \
580   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
581       FixedPoint<tRawType, tIntegerBits> a) {                                  \
582     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
583   }
584 
585 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
586   template <typename tRawType, int tIntegerBits>            \
587   FixedPoint<tRawType, tIntegerBits> FuncName(              \
588       FixedPoint<tRawType, tIntegerBits> a,                 \
589       FixedPoint<tRawType, tIntegerBits> b) {               \
590     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
591         ImplFuncName(a.raw(), b.raw()));                    \
592   }
593 
594 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
595 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
596 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
597 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
598 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
599 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
600 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
601 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
602 
603 #undef MAKE_FIXEDPOINT_UNARY_FUNC
604 #undef MAKE_FIXEDPOINT_BINARY_FUNC
605 
606 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
607   template <typename tRawType, int tIntegerBits>            \
608   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
609     return FuncName(a.raw());                               \
610   }
611 
612 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
613   template <typename tRawType, int tIntegerBits>            \
614   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
615                     FixedPoint<tRawType, tIntegerBits> b) { \
616     return FuncName(a.raw(), b.raw());                      \
617   }
618 
619 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
620 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
621 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
622 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
623 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
624 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
625 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
626 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
627 
628 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
629 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
630 
631 template <typename tRawType, int tIntegerBits>
632 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
633     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
634     FixedPoint<tRawType, tIntegerBits> else_val) {
635   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
636       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
637 }
638 
639 template <typename tRawType, int tIntegerBits>
640 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
641                 FixedPoint<tRawType, tIntegerBits> b) {
642   return All(MaskIfEqual(a.raw(), b.raw()));
643 }
644 
645 template <typename tRawType, int tIntegerBits>
646 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
647                 FixedPoint<tRawType, tIntegerBits> b) {
648   return !(a == b);
649 }
650 
651 template <typename tRawType, int tIntegerBits>
652 FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
653     FixedPoint<tRawType, tIntegerBits> a,
654     FixedPoint<tRawType, tIntegerBits> b) {
655   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
656       SaturatingAdd(a.raw(), b.raw()));
657 }
658 
659 template <typename tRawType, int tIntegerBits>
660 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
661     FixedPoint<tRawType, tIntegerBits> a,
662     FixedPoint<tRawType, tIntegerBits> b) {
663   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
664       AddSaturatingIf16Bit(a.raw(), b.raw()));
665 }
666 
667 // Conversion to floating-point.
668 template <typename tRawType, int tIntegerBits>
669 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
670   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
671                 "not applicable to SIMD types");
672   typedef FixedPoint<tRawType, tIntegerBits> F;
673   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
674 }
675 
676 // Rescale changes the number of IntegerBits and updates the underlying
677 // raw integer value accordingly.
678 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
679 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
680     FixedPoint<tRawType, tIntegerBitsSrc> x) {
681   static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
682   FixedPoint<tRawType, tIntegerBitsDst> result;
683   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
684   return result;
685 }
686 
// CheckedFixedPointConstant allows specifying fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always an int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
697 template <typename FixedPointType>
698 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
699     std::int32_t int32_value) {
700   typedef typename FixedPointType::ScalarRawType ScalarRawType;
701   static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
702   return static_cast<ScalarRawType>(
703       RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
704 }
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Builds a FixedPointType constant from raw_value and, when asserts are
// enabled, checks that it equals the conversion of double_value.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  assert(constant == FixedPointType::FromDouble(double_value));
  return constant;
}
// Checking variant: rescales the int32 raw value to the scalar type of Type
// and verifies it against the real value.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(Type, RawInt32, RealValue) \
  (gemmlowp::CheckedFixedPointConstant<Type>(                           \
      gemmlowp::RescaleConstantInitializer<Type>(RawInt32), RealValue))

#else
// Non-checking variant: rescales the int32 raw value to the scalar type of
// Type; the real value is not used at all.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(Type, RawInt32, RealValue) \
  (Type::FromScalarRaw(gemmlowp::RescaleConstantInitializer<Type>(RawInt32)))
#endif
727 
728 // Implementation of exponential function.
729 
730 // Returns exp(x) for x in [-1/4, 0).
731 template <typename tRawType>
732 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
733     FixedPoint<tRawType, 0> a) {
734   typedef FixedPoint<tRawType, 0> F;
735   const F constant_term =
736       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
737   const F constant_1_over_3 =
738       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
739   // We're evaluating a Taylor expansion around -1/8, so we do the change of
740   // variable: x = a + 1/8.
741   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
742   F x = a + F::template ConstantPOT<-3>();
743   F x2 = x * x;
744   F x3 = x2 * x;
745   F x4 = x2 * x2;
746   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
747   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
748       SaturatingRoundingMultiplyByPOT<-1>(
749           ((x4_over_4 + x3) * constant_1_over_3) + x2);
750   return AddSaturatingIf16Bit(
751       constant_term,
752       constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
753 }
754 
755 // Returns exp(x) for x < 0.
756 template <typename tRawType, int tIntegerBits>
757 FixedPoint<tRawType, 0> exp_on_negative_values(
758     FixedPoint<tRawType, tIntegerBits> a) {
759   typedef FixedPoint<tRawType, tIntegerBits> InputF;
760   typedef FixedPoint<tRawType, 0> ResultF;
761   static constexpr int kFractionalBits = InputF::kFractionalBits;
762   static constexpr int kIntegerBits = InputF::kIntegerBits;
763   const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
764   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
765   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
766   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
767       Rescale<0>(a_mod_quarter_minus_one_quarter));
768   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
769 
770 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
771   if (kIntegerBits > Exponent) {                                            \
772     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
773         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
774     static constexpr int kShiftAmount =                                     \
775         kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
776     result = SelectUsingMask(                                               \
777         MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
778         result * kMultiplier, result);                                      \
779   }
780 
781   // Constants below are Q0 representations of negative exp fractionals:
782   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
783   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
784   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
785   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
786   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
787   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
788   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)
789 
790 #undef GEMMLOWP_EXP_BARREL_SHIFTER
791 
792   static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
793   if (kIntegerBits > 5) {
794     const InputF clamp =
795         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
796     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
797   }
798 
799   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
800   return result;
801 }
802 
803 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
804 
805 // Returns (1 - x) / (1 + x) for x in (0, 1).
806 template <typename tRawType>
807 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
808     FixedPoint<tRawType, 0> a) {
809   typedef FixedPoint<tRawType, 0> F0;
810   typedef FixedPoint<tRawType, 2> F2;
811   F0 half_denominator = RoundingHalfSum(a, F0::One());
812   // Newton-Raphson division
813   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
814   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
815   const F2 constant_48_over_17 =
816       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
817   const F2 constant_neg_32_over_17 =
818       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
819   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
820   for (int i = 0; i < 3; i++) {
821     F2 half_denominator_times_x = half_denominator * x;
822     F2 one_minus_half_denominator_times_x =
823         F2::One() - half_denominator_times_x;
824     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
825   }
826   return Rescale<0>(x - F2::One());
827 }
828 
829 // Returns -tanh(x) for x < 0.
830 template <typename tRawType, int tIntegerBits>
831 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
832     FixedPoint<tRawType, tIntegerBits> a) {
833   return one_minus_x_over_one_plus_x_for_x_in_0_1(
834       exp_on_negative_values(ExactMulByPot<1>(a)));
835 }
836 
837 // Returns tanh(x) for any x.
838 template <typename tRawType, int tIntegerBits>
839 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
840   typedef FixedPoint<tRawType, tIntegerBits> InputF;
841   typedef FixedPoint<tRawType, 0> ResultF;
842   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
843   tRawType mask_if_zero = MaskIfZero(a);
844   InputF n = SelectUsingMask(mask_if_negative, a, -a);
845   ResultF t = neg_tanh_on_negative_values(n);
846   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
847                          SelectUsingMask(mask_if_negative, -t, t));
848 }
849 
850 // Implementation of logistic function.
851 
852 // Returns 1 / (1 + x) for x in (0, 1).
853 template <typename tRawType>
854 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
855     FixedPoint<tRawType, 0> a) {
856   typedef FixedPoint<tRawType, 0> F0;
857   typedef FixedPoint<tRawType, 2> F2;
858   F0 half_denominator = RoundingHalfSum(a, F0::One());
859   // Newton-Raphson division
860   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
861   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
862   const F2 constant_48_over_17 =
863       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
864   const F2 constant_neg_32_over_17 =
865       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
866   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
867   for (int i = 0; i < 3; i++) {
868     F2 half_denominator_times_x = half_denominator * x;
869     F2 one_minus_half_denominator_times_x =
870         F2::One() - half_denominator_times_x;
871     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
872   }
873   return Rescale<0>(ExactMulByPot<-1>(x));
874 }
875 
876 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
877 template <typename tRawType, int tIntegerBits>
878 FixedPoint<tRawType, 0> logistic_on_positive_values(
879     FixedPoint<tRawType, tIntegerBits> a) {
880   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
881 }
882 
883 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
884 template <typename tRawType, int tIntegerBits>
885 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
886   typedef FixedPoint<tRawType, tIntegerBits> InputF;
887   typedef FixedPoint<tRawType, 0> ResultF;
888   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
889   tRawType mask_if_zero = MaskIfZero(a);
890   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
891   ResultF result_if_positive = logistic_on_positive_values(abs_input);
892   ResultF result_if_negative = ResultF::One() - result_if_positive;
893   const ResultF one_half =
894       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
895   return SelectUsingMask(mask_if_zero, one_half,
896                          SelectUsingMask(mask_if_positive, result_if_positive,
897                                          result_if_negative));
898 }
899 
900 }  // end namespace gemmlowp
901 
902 #ifdef GEMMLOWP_NEON
903 #include "./fixedpoint_neon.h"
904 #elif defined(GEMMLOWP_AVX2)
905 #include "./fixedpoint_avx.h"
906 #elif defined(GEMMLOWP_SSE4)
907 #include "./fixedpoint_sse.h"
908 #elif defined(GEMMLOWP_MSA)
909 #include "./fixedpoint_msa.h"
910 #elif defined(GEMMLOWP_WASMSIMD)
911 #include "./fixedpoint_wasmsimd.h"
912 #endif
913 
914 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
915