1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
16 // a few math functions such as tanh.
17 
18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
20 
21 #include <cassert>
22 #include <limits>
23 
24 #include "../internal/common.h"
25 
26 namespace gemmlowp {
27 
28 // Part 1: Low-level integer-arithmetic primitives.
29 // The implementations here are generic implementations valid for
30 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
31 // (e.g. NEON int32x4_t) may be supported by providing
32 // specializations for them in separate files.
33 //
34 // The purpose of these primitives is two-fold:
35 //  - They will be used to implement higher-level fixed-point
36 //    abstractions, namely the FixedPoint class and its arithmetic
37 //    operators.
38 //  - They will be directly used to implement some more involved
39 //    fixed-point computations, e.g. the fixed-point implementation
40 //    of math functions such as tanh.
41 
// Compile-time description of a raw storage type, used to abstract over
// scalar and SIMD representations. Each specialization exposes:
//   - ScalarRawType: the scalar integer type held in each lane;
//   - kLanes: the number of lanes carried by one raw value.
// The primary template is deliberately empty, so an unsupported raw type
// fails to compile as soon as one of these members is used.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Scalar 32-bit raw type: one lane of std::int32_t.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static const int kLanes = 1;
};

// Scalar 16-bit raw type: one lane of std::int16_t.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static const int kLanes = 1;
};
58 
59 // Returns a SIMD value duplicating a scalar value across all lanes.
60 template <typename tRawType>
61 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
62   return x;
63 }
64 
// Bit-wise AND of the two operands.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType conjunction = a & b;
  return conjunction;
}
70 
// Bit-wise OR of the two operands.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType disjunction = a | b;
  return disjunction;
}
76 
// Bit-wise exclusive OR of the two operands.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType difference_bits = a ^ b;
  return difference_bits;
}
82 
// Bit-wise complement of the operand.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType complement = ~a;
  return complement;
}
88 
// Plain integer addition: a + b.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}
94 
// Integer multiplication: a * b.
// Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}
100 
// Plain integer subtraction: a - b.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}
105 
// Arithmetic negation: -a.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
111 
// Left shift of a by `offset` bits, i.e. multiplication by 2^offset.
// No saturation is applied; overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  const tIntegerType shifted = a << offset;
  return shifted;
}
118 
// Right shift of a by `offset` bits, without any rounding.
// For negative values this relies on the compiler implementing `>>` as an
// arithmetic shift, which is implementation-defined but consistent in
// practice.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
126 
// Bit-wise select: where a bit of if_mask is set, the result takes the
// corresponding bit of then_val; where it is clear, the bit of else_val.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType taken_bits = BitAnd(if_mask, then_val);
  const tIntegerType fallback_bits = BitAnd(BitNot(if_mask), else_val);
  // The two parts never share a set bit, so XOR merges them.
  return BitXor(taken_bits, fallback_bits);
}
135 
// Builds a mask from a: all bits set when a is non-zero, all bits clear
// when a is zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType no_bits = 0;
  if (a) {
    return BitNot(no_bits);
  }
  return no_bits;
}
143 
144 // For each input scalar, the corresponding bits of the result are set if the
145 // input scalar is zero.
146 template <typename tIntegerType>
147 tIntegerType MaskIfZero(tIntegerType a) {
148   return MaskIfNonZero<tIntegerType>(!a);
149 }
150 
151 // For each pair of input scalars, the corresponding bits of the result are
152 // set if the input scalars are equal.
153 template <typename tIntegerType>
154 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
155   return MaskIfNonZero<tIntegerType>(a == b);
156 }
157 
158 // For each pair of input scalars, the corresponding bits of the result are
159 // set if the input scalars are not equal.
160 template <typename tIntegerType>
161 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
162   return MaskIfNonZero<tIntegerType>(a != b);
163 }
164 
165 // For each pair of input scalars, the corresponding bits of the result are
166 // set if the input scalars a, b satisfy a > b.
167 template <typename tIntegerType>
168 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
169   return MaskIfNonZero<tIntegerType>(a > b);
170 }
171 
172 // For each pair of input scalars, the corresponding bits of the result are
173 // set if the input scalars a, b satisfy a >= b.
174 template <typename tIntegerType>
175 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
176   return MaskIfNonZero<tIntegerType>(a >= b);
177 }
178 
179 // For each pair of input scalars, the corresponding bits of the result are
180 // set if the input scalars a, b satisfy a < b.
181 template <typename tIntegerType>
182 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
183   return MaskIfNonZero<tIntegerType>(a < b);
184 }
185 
186 // For each pair of input scalars, the corresponding bits of the result are
187 // set if the input scalars a, b satisfy a <= b.
188 template <typename tIntegerType>
189 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
190   return MaskIfNonZero<tIntegerType>(a <= b);
191 }
192 
// Reports whether every input scalar is non-zero.
// Callers are expected to pass mask-like values, where each scalar has
// either all of its bits set or none; behavior for other inputs is
// currently unspecified.
template <typename tIntegerType>
bool All(tIntegerType a) {
  const bool nonzero = a;
  return nonzero;
}
200 
// Reports whether at least one input scalar is non-zero.
// Callers are expected to pass mask-like values, where each scalar has
// either all of its bits set or none; behavior for other inputs is
// currently unspecified.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  const bool nonzero = a;
  return nonzero;
}
208 
// Computes (a + b) / 2 rounded to the nearest integer.
// Described as the counterpart of VRHADD in the ARM NEON instruction set.
// The generic template only exists to reject unsupported types at compile
// time; the real work is done by the specializations below.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// 32-bit case: the sum is formed in 64 bits so it cannot overflow. A unit
// with the sign of the sum is added before the truncating division, so
// exact halves are rounded away from zero.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  const std::int64_t rounding_unit = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((wide_sum + rounding_unit) / 2);
}

// 16-bit case: same scheme, with the sum formed in 32 bits.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t rounding_unit = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int16_t>((wide_sum + rounding_unit) / 2);
}
234 
// Returns a + b, clamped to the representable range of IntegerType.
// The generic template only exists to reject unsupported types at compile
// time; see the specialization below.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// So far this is only needed for int16.
// The sum is formed in 32 bits, where it cannot overflow, and then clamped
// to the int16 range. The clamp is done explicitly in std::int32_t so that
// std::min/std::max do not depend on std::int32_t being the same type as
// int.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t lowest = std::numeric_limits<std::int16_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int16_t>::max();
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  return static_cast<std::int16_t>(
      std::min<std::int32_t>(highest, std::max<std::int32_t>(lowest, sum)));
}
249 
250 // Returns a+b, saturating if the integers are 16bit or narrower,
251 // otherwise just a plain addition.
252 template <typename IntegerType, bool Is16Bit>
253 struct AddSaturatingIf16BitImpl {
254   static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
255 };
256 template <typename IntegerType>
257 struct AddSaturatingIf16BitImpl<IntegerType, true> {
258   static IntegerType Run(IntegerType a, IntegerType b) {
259     return SaturatingAdd(a, b);
260   }
261 };
262 template <typename IntegerType>
263 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
264   using ScalarType =
265       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
266   return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
267                                                                              b);
268 }
269 
270 // Returns the integer that represents the product of two fixed-point
271 // numbers, interpreting all integers as fixed-point values in the
272 // interval [-1, 1), rounding to the nearest value, and saturating
273 // -1 * -1 to the maximum value (since 1 is not in the half-open
274 // interval [-1, 1)).
275 //
276 // [The explanation below specializes to std::int32_t for example purpose.]
277 //
278 // The mapping between IntegerType and the interval [-1, 1) is unique and
279 // implied by IntegerType, which is assumed to be signed. For example,
280 // for IntegerType==std::int32_t, the mapping is
281 //   real_value = integer_value / 2^31.
282 // So in this case, and leaving aside rounding and saturating, this
283 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
284 //   (a * b) / 2^31.
285 //
286 // The 'doubling' part in the name of this function comes from the fact that
287 // this operation is very close to a "multiply-high" operation, keeping only
288 // the top half bits, except that that would be effectively computing
289 //   (a * b) / 2^32,
290 // so here we are computing 2x that, since
291 //   1/2^31 = 2 * 1/2^32.
292 // The idea is to use all of the available 32 bits in the destination int32
293 // value.
294 //
295 // [End of the explanation specializing to int32.]
296 //
297 // This is equivalent to the VQRDMULH instruction in ARM NEON.
// Generic entry point; it only exists to reject unsupported types at
// compile time. The supported scalar types are handled by the
// specializations below.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// 32-bit case, implementing the same computation as the ARMv7 NEON VQRDMULH
// instruction: (a * b) / 2^31, rounded to nearest, with the single
// overflowing input pair (min, min) saturated to the maximum value.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  const std::int32_t lowest = std::numeric_limits<std::int32_t>::min();
  if (a == lowest && b == lowest) {
    // (-1) * (-1) would be +1, which is not representable: saturate.
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Half of the divisor, signed like the product, so that the truncating
  // division below rounds to nearest.
  const std::int32_t nudge = product >= 0 ? (1 << 30) : (1 - (1 << 30));
  return static_cast<std::int32_t>((product + nudge) / (1ll << 31));
}

// 16-bit case: same scheme, computing (a * b) / 2^15 in 32-bit arithmetic.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  const std::int16_t lowest = std::numeric_limits<std::int16_t>::min();
  if (a == lowest && b == lowest) {
    return std::numeric_limits<std::int16_t>::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int16_t nudge = product >= 0 ? (1 << 14) : (1 - (1 << 14));
  return static_cast<std::int16_t>((product + nudge) / (1 << 15));
}
331 
332 // Correctly-rounded-to-nearest division by a power-of-two.
333 // Also known as a rounding arithmetic right shift.
334 template <typename IntegerType>
335 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
336   assert(exponent >= 0);
337   assert(exponent <= 31);
338   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
339   const IntegerType zero = Dup<IntegerType>(0);
340   const IntegerType one = Dup<IntegerType>(1);
341   const IntegerType remainder = BitAnd(x, mask);
342   const IntegerType threshold =
343       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
344   return Add(ShiftRight(x, exponent),
345              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
346 }
347 
348 // Returns the product of a run-time integer value by a compile-time power
349 // of two, with either a positive exponent (equivalent to an arithmetic
350 // left shift, saturating) or a negative exponent (equivalent to an arithmetic
351 // right shift, rounding to nearest).
352 template <int Exponent, typename IntegerType,
353           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
354 struct ImplSaturatingRoundingMultiplyByPOT {};
355 
356 template <int Exponent, typename IntegerType>
357 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
358   static IntegerType eval(IntegerType x) { return x; }
359 };
360 
361 template <int Exponent, typename IntegerType>
362 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
363   static IntegerType eval(IntegerType x) {
364     using ScalarIntegerType =
365         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
366     const IntegerType min =
367         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
368     const IntegerType max =
369         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
370     const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
371 
372     const std::int32_t threshold =
373         ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
374     const IntegerType positive_mask =
375         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
376     const IntegerType negative_mask =
377         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
378 
379     IntegerType result = ShiftLeft(x, Exponent);
380     result = SelectUsingMask(positive_mask, max, result);
381     result = SelectUsingMask(negative_mask, min, result);
382     return result;
383   }
384 };
385 
386 template <int Exponent, typename IntegerType>
387 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
388   static IntegerType eval(IntegerType x) {
389     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
390   }
391 };
392 
393 template <int Exponent, typename IntegerType>
394 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
395   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
396 }
397 
398 // Part 2: the FixedPoint class.
399 
400 // A FixedPoint object represents a fixed-point value stored in the underlying
401 // integer type tRawType, if tRawType is a plain scalar integer type.
402 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
403 // case a FixedPoint object represents a corresponding SIMD vector of fixed
404 // point values.
405 //
406 // tIntegerBits describes the range of the fixed-point format: if
407 // tIntegerBits == m then the range of representable values is the half-open
408 // interval [-2^m; 2^m) where the open boundary on the right side means that
409 // 2^m is not representable (how close the maximum representable value is to
410 // it, depends on bit-depth of tRawType).
411 //
412 // In "Q format notation",
413 //   https://en.wikipedia.org/wiki/Q_(number_format)
414 // we are describing the format
415 //   Qm.n
416 // where
417 //   m = tIntegerBits
418 // and
419 //   n = NumberOfBits(tRawType) - (m + 1)
420 // Note that the (m + 1) in the above line is because we adopt the convention
421 // that we count the integer bits exclusively of the sign bit; so (m + 1) is
422 // the total number of integer bits inclusive of the sign bit.
423 //
424 // Accordingly, the number of integral representable values in our range
425 //   [-2^m ; 2^m)
426 // is equal to 2^(m+1).
427 template <typename tRawType, int tIntegerBits>
428 class FixedPoint {
429  public:
430   typedef tRawType RawType;
431 
432   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
433   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
434 
435   static const int kTotalBits = 8 * sizeof(ScalarRawType);
436   static const int kIntegerBits = tIntegerBits;
437   static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
438   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
439                 "bad IntegerBits");
440 
441   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
442 
443   static const ScalarRawType ScalarRawMin() {
444     return std::numeric_limits<ScalarRawType>::min();
445   }
446 
447   static const ScalarRawType ScalarRawMax() {
448     return std::numeric_limits<ScalarRawType>::max();
449   }
450 
451   static const ScalarRawType RawMin() {
452     return VectorFromScalar(ScalarRawMin());
453   }
454 
455   static const ScalarRawType RawMax() {
456     return VectorFromScalar(ScalarRawMax());
457   }
458 
459   static FixedPoint FromRaw(RawType x) {
460     FixedPoint retval;
461     retval.raw() = x;
462     return retval;
463   }
464 
465   static FixedPoint FromScalarRaw(ScalarRawType x) {
466     FixedPoint retval;
467     retval.raw() = Dup<RawType>(x);
468     return retval;
469   }
470 
471   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
472     return FromScalarRaw(x.raw());
473   }
474 
475   template <int Exponent>
476   static FixedPoint ConstantPOT() {
477     static const int kOffset = kFractionalBits + Exponent;
478     static_assert(
479         kOffset < 31,
480         "Constant not exactly representable in this fixed-point format");
481     return FromScalarRaw(ScalarRawType(1) << kOffset);
482   }
483 
484   static FixedPoint Zero() { return FromScalarRaw(0); }
485 
486   static FixedPoint One() {
487     return FromScalarRaw(
488         kIntegerBits == 0
489             ? ScalarRawMax()
490             : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
491   }
492 
493   static FixedPoint FromDouble(double x) {
494     const double min_bound = static_cast<double>(ScalarRawMin());
495     const double max_bound = static_cast<double>(ScalarRawMax());
496     return FromScalarRaw(static_cast<ScalarRawType>(std::min(
497         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
498                  min_bound),
499         max_bound)));
500   }
501 
502   RawType raw() const { return i_; }
503   RawType& raw() { return i_; }
504 
505  private:
506   RawType i_;
507 };
508 
509 // Part 3: implementation of arithmetic operators for the
510 // FixedPoint class, and a few related functions.
511 
512 // A FixedPoint multiplication is just a
513 // SaturatingRoundingDoublingHighMul operation on the underlying
514 // raw integer values. The IntegerBits simply add up, as is obvious
515 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
516 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
517 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
518     FixedPoint<tRawType, tIntegerBits_a> a,
519     FixedPoint<tRawType, tIntegerBits_b> b) {
520   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
521   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
522   return c;
523 }
524 
525 // Tweaking IntegerBits gives exact multiplication by a power of two.
526 template <int tExponent, typename tRawType, int tIntegerBits>
527 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
528     FixedPoint<tRawType, tIntegerBits> a) {
529   FixedPoint<tRawType, tExponent + tIntegerBits> c;
530   c.raw() = a.raw();
531   return c;
532 }
533 
534 // If we want to leave IntegerBits fixed, then multiplication
535 // by a power of two has to be saturating/rounding, not exact anymore.
536 template <int tExponent, typename tRawType, int tIntegerBits>
537 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
538     FixedPoint<tRawType, tIntegerBits> a) {
539   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
540       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
541 }
542 
543 // Generic arithmetic operators.
544 
545 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
546   template <typename tRawType, int tIntegerBits>                               \
547   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
548       FixedPoint<tRawType, tIntegerBits> a) {                                  \
549     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
550   }
551 
552 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
553   template <typename tRawType, int tIntegerBits>            \
554   FixedPoint<tRawType, tIntegerBits> FuncName(              \
555       FixedPoint<tRawType, tIntegerBits> a,                 \
556       FixedPoint<tRawType, tIntegerBits> b) {               \
557     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
558         ImplFuncName(a.raw(), b.raw()));                    \
559   }
560 
561 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
562 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
563 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
564 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
565 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
566 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
567 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
568 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
569 
570 #undef MAKE_FIXEDPOINT_UNARY_FUNC
571 #undef MAKE_FIXEDPOINT_BINARY_FUNC
572 
573 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
574   template <typename tRawType, int tIntegerBits>            \
575   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
576     return FuncName(a.raw());                               \
577   }
578 
579 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
580   template <typename tRawType, int tIntegerBits>            \
581   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
582                     FixedPoint<tRawType, tIntegerBits> b) { \
583     return FuncName(a.raw(), b.raw());                      \
584   }
585 
586 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
587 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
588 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
589 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
590 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
591 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
592 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
593 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
594 
595 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
596 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
597 
598 template <typename tRawType, int tIntegerBits>
599 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
600     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
601     FixedPoint<tRawType, tIntegerBits> else_val) {
602   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
603       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
604 }
605 
606 template <typename tRawType, int tIntegerBits>
607 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
608                 FixedPoint<tRawType, tIntegerBits> b) {
609   return All(MaskIfEqual(a.raw(), b.raw()));
610 }
611 
612 template <typename tRawType, int tIntegerBits>
613 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
614                 FixedPoint<tRawType, tIntegerBits> b) {
615   return !(a == b);
616 }
617 
618 template <typename tRawType, int tIntegerBits>
619 FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
620     FixedPoint<tRawType, tIntegerBits> a,
621     FixedPoint<tRawType, tIntegerBits> b) {
622   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
623       SaturatingAdd(a.raw(), b.raw()));
624 }
625 
626 template <typename tRawType, int tIntegerBits>
627 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
628     FixedPoint<tRawType, tIntegerBits> a,
629     FixedPoint<tRawType, tIntegerBits> b) {
630   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
631       AddSaturatingIf16Bit(a.raw(), b.raw()));
632 }
633 
634 // Conversion to floating-point.
635 template <typename tRawType, int tIntegerBits>
636 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
637   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
638                 "not applicable to SIMD types");
639   typedef FixedPoint<tRawType, tIntegerBits> F;
640   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
641 }
642 
643 // Rescale changes the number of IntegerBits and updates the underlying
644 // raw integer value accordingly.
645 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
646 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
647     FixedPoint<tRawType, tIntegerBitsSrc> x) {
648   static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
649   FixedPoint<tRawType, tIntegerBitsDst> result;
650   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
651   return result;
652 }
653 
// CheckedFixedPointConstant lets fixed-point constants be specified as real
// numbers without compiling any floating-point arithmetic into production
// code, while still checking agreement with the floating-point expressions
// when asserts are enabled.
//
// The raw integer value provided is always an int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
664 template <typename FixedPointType>
665 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
666     std::int32_t int32_value) {
667   typedef typename FixedPointType::ScalarRawType ScalarRawType;
668   static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
669   return static_cast<ScalarRawType>(
670       RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
671 }
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Builds a fixed-point constant from its raw value and, when asserts are
// enabled, checks that it agrees with the conversion of double_value.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  // The floating-point conversion lives inside the assert so that it
  // disappears together with it.
  assert(FixedPointType::FromDouble(double_value) == constant);
  return constant;
}
// Checking flavor: the 32-bit raw value is rescaled to the scalar type of
// FixedPointType, then verified against DoubleValue.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (gemmlowp::CheckedFixedPointConstant<FixedPointType>(                        \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value),                                                \
      DoubleValue))

#else
// Non-checking flavor: DoubleValue is ignored, so no floating-point
// expression is evaluated.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (FixedPointType::FromScalarRaw(                                              \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value)))
#endif
694 
695 // Implementation of exponential function.
696 
697 // Returns exp(x) for x in [-1/4, 0).
698 template <typename tRawType>
699 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
700     FixedPoint<tRawType, 0> a) {
701   typedef FixedPoint<tRawType, 0> F;
702   const F constant_term =
703       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
704   const F constant_1_over_3 =
705       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
706   // We're evaluating a Taylor expansion around -1/8, so we do the change of
707   // variable: x = a + 1/8.
708   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
709   F x = a + F::template ConstantPOT<-3>();
710   F x2 = x * x;
711   F x3 = x2 * x;
712   F x4 = x2 * x2;
713   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
714   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
715       SaturatingRoundingMultiplyByPOT<-1>(
716           ((x4_over_4 + x3) * constant_1_over_3) + x2);
717   return AddSaturatingIf16Bit(
718       constant_term,
719       constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
720 }
721 
722 // Returns exp(x) for x < 0.
723 template <typename tRawType, int tIntegerBits>
724 FixedPoint<tRawType, 0> exp_on_negative_values(
725     FixedPoint<tRawType, tIntegerBits> a) {
726   typedef FixedPoint<tRawType, tIntegerBits> InputF;
727   typedef FixedPoint<tRawType, 0> ResultF;
728   static const int kFractionalBits = InputF::kFractionalBits;
729   static const int kIntegerBits = InputF::kIntegerBits;
730   static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
731   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
732   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
733   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
734       Rescale<0>(a_mod_quarter_minus_one_quarter));
735   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
736 
737 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
738   if (kIntegerBits > Exponent) {                                            \
739     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
740         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
741     static constexpr int kShiftAmount =                                     \
742         kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
743     result = SelectUsingMask(                                               \
744         MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
745         result * kMultiplier, result);                                      \
746   }
747 
748   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
749   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
750   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
751   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
752   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
753   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
754   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
755 
756 #undef GEMMLOWP_EXP_BARREL_SHIFTER
757 
758   if (kIntegerBits > 5) {
759     static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
760     const InputF clamp =
761         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
762     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
763   }
764 
765   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
766   return result;
767 }
768 
769 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
770 
771 // Returns (1 - x) / (1 + x) for x in (0, 1).
772 template <typename tRawType>
773 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
774     FixedPoint<tRawType, 0> a) {
775   typedef FixedPoint<tRawType, 0> F0;
776   typedef FixedPoint<tRawType, 2> F2;
777   F0 half_denominator = RoundingHalfSum(a, F0::One());
778   // Newton-Raphson division
779   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
780   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
781   const F2 constant_48_over_17 =
782       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
783   const F2 constant_neg_32_over_17 =
784       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
785   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
786   for (int i = 0; i < 3; i++) {
787     F2 half_denominator_times_x = half_denominator * x;
788     F2 one_minus_half_denominator_times_x =
789         F2::One() - half_denominator_times_x;
790     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
791   }
792   return Rescale<0>(x - F2::One());
793 }
794 
795 // Returns -tanh(x) for x < 0.
796 template <typename tRawType, int tIntegerBits>
797 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
798     FixedPoint<tRawType, tIntegerBits> a) {
799   return one_minus_x_over_one_plus_x_for_x_in_0_1(
800       exp_on_negative_values(ExactMulByPot<1>(a)));
801 }
802 
803 // Returns tanh(x) for any x.
804 template <typename tRawType, int tIntegerBits>
805 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
806   typedef FixedPoint<tRawType, tIntegerBits> InputF;
807   typedef FixedPoint<tRawType, 0> ResultF;
808   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
809   tRawType mask_if_zero = MaskIfZero(a);
810   InputF n = SelectUsingMask(mask_if_negative, a, -a);
811   ResultF t = neg_tanh_on_negative_values(n);
812   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
813                          SelectUsingMask(mask_if_negative, -t, t));
814 }
815 
816 // Implementation of logistic function.
817 
818 // Returns 1 / (1 + x) for x in (0, 1).
819 template <typename tRawType>
820 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
821     FixedPoint<tRawType, 0> a) {
822   typedef FixedPoint<tRawType, 0> F0;
823   typedef FixedPoint<tRawType, 2> F2;
824   F0 half_denominator = RoundingHalfSum(a, F0::One());
825   // Newton-Raphson division
826   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
827   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
828   const F2 constant_48_over_17 =
829       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
830   const F2 constant_neg_32_over_17 =
831       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
832   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
833   for (int i = 0; i < 3; i++) {
834     F2 half_denominator_times_x = half_denominator * x;
835     F2 one_minus_half_denominator_times_x =
836         F2::One() - half_denominator_times_x;
837     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
838   }
839   return Rescale<0>(ExactMulByPot<-1>(x));
840 }
841 
842 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
843 template <typename tRawType, int tIntegerBits>
844 FixedPoint<tRawType, 0> logistic_on_positive_values(
845     FixedPoint<tRawType, tIntegerBits> a) {
846   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
847 }
848 
849 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
850 template <typename tRawType, int tIntegerBits>
851 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
852   typedef FixedPoint<tRawType, tIntegerBits> InputF;
853   typedef FixedPoint<tRawType, 0> ResultF;
854   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
855   tRawType mask_if_zero = MaskIfZero(a);
856   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
857   ResultF result_if_positive = logistic_on_positive_values(abs_input);
858   ResultF result_if_negative = ResultF::One() - result_if_positive;
859   const ResultF one_half =
860       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
861   return SelectUsingMask(mask_if_zero, one_half,
862                          SelectUsingMask(mask_if_positive, result_if_positive,
863                                          result_if_negative));
864 }
865 
866 }  // end namespace gemmlowp
867 
868 #ifdef GEMMLOWP_NEON
869 #include "./fixedpoint_neon.h"
870 #elif defined(GEMMLOWP_SSE4)
871 #include "./fixedpoint_sse.h"
872 #elif defined(GEMMLOWP_MSA)
873 #include "./fixedpoint_msa.h"
874 #endif
875 
876 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
877