1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
16 // a few math functions such as tanh.
17 
18 // This is only used in output.h
19 // for some specific output pipeline stages (tanh); most of gemmlowp
20 // uses only plain integer arithmetic, not fixed-point arithmetic.
21 // At the most basic level, we distinguish between plain integer
22 // arithmetic and fixed-point arithmetic by the type of multiplication
23 // that is used: plain integer arithmetic uses plain (overflowing)
24 // integer multiplication, whereas fixed-point arithmetic uses
25 // "multiply-high" instructions, which means using only the most
26 // significant bits of the product, or equivalently, multiplying
27 // fixed-point numbers in the [-1 .. +1] interval.
28 
29 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
30 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
31 
#include "common.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
36 
37 namespace gemmlowp {
38 
// Returns the bitwise AND of |a| and |b|.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType common_bits = a & b;
  return common_bits;
}
43 
// Returns the bitwise OR of |a| and |b|.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType merged_bits = a | b;
  return merged_bits;
}
48 
// Returns the bitwise exclusive-OR of |a| and |b|.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType differing_bits = a ^ b;
  return differing_bits;
}
53 
// Returns the bitwise complement of |a|.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType flipped_bits = ~a;
  return flipped_bits;
}
58 
// Returns a + b using the plain addition of tIntegerType (no saturation).
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType total = a + b;
  return total;
}
63 
// Returns a - b using the plain subtraction of tIntegerType (no saturation).
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}
68 
// Returns the arithmetic negation of |a| (no saturation).
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
73 
// Multiplies |a| by 2^offset. Implemented as a multiplication rather than
// with operator<<, so negative values of |a| are handled arithmetically.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  const int power_of_two = 1 << offset;
  return a * power_of_two;
}
78 
// Divides |a| by 2^offset. Implemented as an integer division, so the result
// is truncated toward zero (e.g. -7 with offset 1 gives -3, not -4).
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const int power_of_two = 1 << offset;
  return a / power_of_two;
}
83 
// Bitwise select: each result bit comes from |then_val| where the matching
// bit of |if_mask| is set, and from |else_val| where it is clear.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType taken_from_then = BitAnd(if_mask, then_val);
  const tIntegerType taken_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two halves never share a set bit, so XOR simply merges them.
  return BitXor(taken_from_then, taken_from_else);
}
89 
// Returns an all-ones mask when |a| is nonzero, and zero otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static const tIntegerType kNoBitsSet = 0;
  if (a) {
    return BitNot(kNoBitsSet);
  }
  return kNoBitsSet;
}
95 
96 template <typename tIntegerType>
MaskIfZero(tIntegerType a)97 tIntegerType MaskIfZero(tIntegerType a) {
98   return MaskIfNonZero<tIntegerType>(!a);
99 }
100 
101 template <typename tIntegerType>
MaskIfEqual(tIntegerType a,tIntegerType b)102 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
103   return MaskIfNonZero<tIntegerType>(a == b);
104 }
105 
106 template <typename tIntegerType>
MaskIfNotEqual(tIntegerType a,tIntegerType b)107 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
108   return MaskIfNonZero<tIntegerType>(a != b);
109 }
110 
111 template <typename tIntegerType>
MaskIfGreaterThan(tIntegerType a,tIntegerType b)112 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
113   return MaskIfNonZero<tIntegerType>(a > b);
114 }
115 
116 template <typename tIntegerType>
MaskIfGreaterThanOrEqual(tIntegerType a,tIntegerType b)117 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
118   return MaskIfNonZero<tIntegerType>(a >= b);
119 }
120 
121 template <typename tIntegerType>
MaskIfLessThan(tIntegerType a,tIntegerType b)122 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
123   return MaskIfNonZero<tIntegerType>(a < b);
124 }
125 
126 template <typename tIntegerType>
MaskIfLessThanOrEqual(tIntegerType a,tIntegerType b)127 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
128   return MaskIfNonZero<tIntegerType>(a <= b);
129 }
130 
// Converts |a| to bool: true when it is nonzero. In this generic version the
// argument is a single value, so this is the same test as Any().
template <typename tIntegerType>
bool All(tIntegerType a) {
  const bool is_set = a;
  return is_set;
}
135 
// Converts |a| to bool: true when it is nonzero. In this generic version the
// argument is a single value, so this is the same test as All().
template <typename tIntegerType>
bool Any(tIntegerType a) {
  const bool is_set = a;
  return is_set;
}
140 
// Generic fallback: deliberately fails to compile when instantiated, so each
// supported integer type must provide its own specialization.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// Returns (a + b) / 2 for 32-bit operands. The sum is formed in 64 bits so it
// cannot overflow, and an exact half is rounded away from zero.
template <>
inline int32_t RoundingHalfSum(int32_t a, int32_t b) {
  const int64_t wide_sum = static_cast<int64_t>(a) + static_cast<int64_t>(b);
  // Nudging by one unit in the direction of the sum's sign turns the
  // truncating division below into round-half-away-from-zero.
  const int64_t direction = (wide_sum < 0) ? -1 : 1;
  return static_cast<int32_t>((wide_sum + direction) / 2);
}
155 
// Generic fallback: deliberately fails to compile when instantiated, so each
// supported integer type must provide its own specialization.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// This function implements the same computation as the ARMv7 NEON VQRDMULH
// instruction: it returns the rounded value of (2 * a * b) / 2^32, saturated
// to the int32_t range.
template <>
inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  // min * min is the only operand pair whose doubled product does not fit.
  if (a == lowest && b == lowest) {
    return std::numeric_limits<int32_t>::max();
  }
  const int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  // Rounding offset for the truncating division by 2^31 below.
  const int32_t rounding = (product >= 0) ? (1 << 30) : (1 - (1 << 30));
  return static_cast<int32_t>((product + rounding) / (1ll << 31));
}
174 
// Implementation of SaturatingRoundingMultiplyByPOT, dispatched on the sign
// of the compile-time Exponent. The primary template is empty; only the
// partial specializations below provide eval().
template <int Exponent, typename IntegerType,
          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
struct ImplSaturatingRoundingMultiplyByPOT {};

// Exponent == 0: multiplying by 2^0 is the identity.
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
  static IntegerType eval(IntegerType x) { return x; }
};

// Exponent > 0: multiply by 2^Exponent, saturating to the int32_t range.
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, 1> {
  static int32_t eval(int32_t x) {
    // Any |x| at or beyond this magnitude would leave the int32_t range.
    const int32_t threshold = 1 << (31 - Exponent);
    if (x >= threshold) {
      return std::numeric_limits<int32_t>::max();
    }
    if (x <= -threshold) {
      return std::numeric_limits<int32_t>::min();
    }
    // The multiplier is built in 64 bits so that Exponent == 31 does not
    // overflow `int`; the guarded product always fits in int32_t.
    return static_cast<int32_t>(x * (static_cast<int64_t>(1) << Exponent));
  }
};

// Exponent < 0: divide by 2^-Exponent, rounding to nearest with exact halves
// rounded away from zero.
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, -1> {
  static int32_t eval(int32_t x) {
    const int shift = -Exponent;
    // Take the magnitude in 64 bits: negating INT32_MIN in 32 bits (as
    // std::abs would) is undefined behavior.
    const int64_t wide_x = x;
    const int64_t magnitude = wide_x < 0 ? -wide_x : wide_x;
    // The highest discarded bit tells whether the dropped fraction is >= 1/2.
    const int32_t half_bit =
        static_cast<int32_t>((magnitude >> (shift - 1)) & 1);
    const int32_t nudge = x >= 0 ? half_bit : -half_bit;
    // Truncating division; the divisor is 64-bit so that shift == 31 is valid.
    const int64_t truncated = wide_x / (static_cast<int64_t>(1) << shift);
    return static_cast<int32_t>(truncated + nudge);
  }
};
203 
204 template <int Exponent, typename IntegerType>
205 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
206   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
207 }
208 
// Traits describing a raw storage type usable by FixedPoint. The primary
// template is empty; each supported raw type provides a specialization.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain int32_t: a single lane whose scalar type is int32_t itself.
template <>
struct FixedPointRawTypeTraits<int32_t> {
  static const int kLanes = 1;
  typedef int32_t ScalarRawType;
};
217 
218 template <typename tRawType>
219 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
220   return x;
221 }
222 
223 template <typename tRawType, int tIntegerBits>
224 class FixedPoint {
225  public:
226   typedef tRawType RawType;
227 
228   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
229   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
230 
231   static const int kTotalBits = 8 * sizeof(ScalarRawType);
232   static const int kIntegerBits = tIntegerBits;
233   static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
234   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
235                 "bad IntegerBits");
236 
237   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
238 
239   static const ScalarRawType ScalarRawMin() {
240     return std::numeric_limits<ScalarRawType>::min();
241   }
242 
243   static const ScalarRawType ScalarRawMax() {
244     return std::numeric_limits<ScalarRawType>::max();
245   }
246 
247   static const ScalarRawType RawMin() {
248     return VectorFromScalar(ScalarRawMin());
249   }
250 
251   static const ScalarRawType RawMax() {
252     return VectorFromScalar(ScalarRawMax());
253   }
254 
255   static FixedPoint FromRaw(RawType x) {
256     FixedPoint retval;
257     retval.raw() = x;
258     return retval;
259   }
260 
261   static FixedPoint FromScalarRaw(ScalarRawType x) {
262     FixedPoint retval;
263     retval.raw() = Dup<RawType>(x);
264     return retval;
265   }
266 
267   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
268     return FromScalarRaw(x.raw());
269   }
270 
271   template <int Exponent>
272   static FixedPoint ConstantPOT() {
273     static const int kOffset = kFractionalBits + Exponent;
274     static_assert(
275         kOffset < 31,
276         "Constant not exactly representable in this fixed-point format");
277     return FromScalarRaw(ScalarRawType(1) << kOffset);
278   }
279 
280   static FixedPoint Zero() { return FromScalarRaw(0); }
281 
282   static FixedPoint One() {
283     return FromScalarRaw(kIntegerBits == 0
284                              ? ScalarRawMax()
285                              : (ScalarRawType(1) << kFractionalBits));
286   }
287 
288   RawType raw() const { return i_; }
289   RawType& raw() { return i_; }
290 
291  private:
292   RawType i_;
293 };
294 
295 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
296 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
297     FixedPoint<tRawType, tIntegerBits_a> a,
298     FixedPoint<tRawType, tIntegerBits_b> b) {
299   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
300   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
301   return c;
302 }
303 
304 template <int tExponent, typename tRawType, int tIntegerBits>
305 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
306     FixedPoint<tRawType, tIntegerBits> a) {
307   FixedPoint<tRawType, tExponent + tIntegerBits> c;
308   c.raw() = a.raw();
309   return c;
310 }
311 
312 template <int tExponent, typename tRawType, int tIntegerBits>
313 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
314     FixedPoint<tRawType, tIntegerBits> a) {
315   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
316       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
317 }
318 
319 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
320   template <typename tRawType, int tIntegerBits>                               \
321   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
322       FixedPoint<tRawType, tIntegerBits> a) {                                  \
323     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
324   }
325 
326 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
327   template <typename tRawType, int tIntegerBits>            \
328   FixedPoint<tRawType, tIntegerBits> FuncName(              \
329       FixedPoint<tRawType, tIntegerBits> a,                 \
330       FixedPoint<tRawType, tIntegerBits> b) {               \
331     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
332         ImplFuncName(a.raw(), b.raw()));                    \
333   }
334 
335 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
336 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
337 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
338 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
339 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
340 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
341 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
342 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
343 
344 #undef MAKE_FIXEDPOINT_UNARY_FUNC
345 #undef MAKE_FIXEDPOINT_BINARY_FUNC
346 
347 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
348   template <typename tRawType, int tIntegerBits>            \
349   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
350     return FuncName(a.raw());                               \
351   }
352 
353 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
354   template <typename tRawType, int tIntegerBits>            \
355   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
356                     FixedPoint<tRawType, tIntegerBits> b) { \
357     return FuncName(a.raw(), b.raw());                      \
358   }
359 
360 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
361 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
362 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
363 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
364 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
365 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
366 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
367 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
368 
369 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
370 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
371 
372 template <typename tRawType, int tIntegerBits>
373 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
374     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
375     FixedPoint<tRawType, tIntegerBits> else_val) {
376   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
377       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
378 }
379 
380 template <typename tRawType, int tIntegerBits>
381 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
382                 FixedPoint<tRawType, tIntegerBits> b) {
383   return All(MaskIfEqual(a.raw(), b.raw()));
384 }
385 
386 template <typename tRawType, int tIntegerBits>
387 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
388                 FixedPoint<tRawType, tIntegerBits> b) {
389   return !(a == b);
390 }
391 
392 template <typename tRawType, int tIntegerBits>
393 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
394   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
395                 "not applicable to SIMD types");
396   typedef FixedPoint<tRawType, tIntegerBits> F;
397   return x.raw() / double(1ll << F::kFractionalBits);
398 }
399 
400 template <typename tRawType, int tIntegerBits>
401 FixedPoint<tRawType, tIntegerBits> ToFixedPoint(double x) {
402   typedef FixedPoint<tRawType, tIntegerBits> F;
403   return F::FromScalarRaw(static_cast<int32_t>(
404       std::min(std::max(round(x * double(1ll << F::kFractionalBits)),
405                         double(F::ScalarRawMin())),
406                double(F::ScalarRawMax()))));
407 }
408 
409 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
410 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
411     FixedPoint<tRawType, tIntegerBitsSrc> x) {
412   static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
413   FixedPoint<tRawType, tIntegerBitsDst> result;
414   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
415   return result;
416 }
417 
// GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(Type, Raw, Double) yields the constant
// of fixed-point type Type whose scalar raw value is Raw. When
// GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS is defined, it also asserts
// that converting Double to Type gives exactly that raw value; otherwise
// Double is ignored.
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(
    typename FixedPointType::ScalarRawType raw_value, double double_value) {
  typedef typename FixedPointType::RawType Raw;
  static const int kBits = FixedPointType::kIntegerBits;
  const FixedPointType ref = FixedPointType::FromScalarRaw(raw_value);
  const FixedPointType check = ToFixedPoint<Raw, kBits>(double_value);
  assert(ref == check);
  return ref;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(Type, RawConstant, DoubleConstant) \
  (CheckedFixedPointConstant<Type>(RawConstant, DoubleConstant))

#else
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(Type, RawConstant, DoubleConstant) \
  (Type::FromScalarRaw(RawConstant))
#endif
438 
439 template <typename tRawType>
440 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
441     FixedPoint<tRawType, 0> a) {
442   typedef FixedPoint<tRawType, 0> F;
443   const F constant_term =
444       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
445   const F constant_1_over_3 =
446       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
447   // We're evaluating a Taylor expansion around -1/8, so we do the change of
448   // variable: x = a + 1/8.
449   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
450   F x = a + F::template ConstantPOT<-3>();
451   F x2 = x * x;
452   F x3 = x2 * x;
453   F x4 = x2 * x2;
454   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
455   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
456       SaturatingRoundingMultiplyByPOT<-1>(
457           ((x4_over_4 + x3) * constant_1_over_3) + x2);
458   return constant_term +
459          constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
460 }
461 
462 template <typename tRawType, int tIntegerBits>
463 FixedPoint<tRawType, 0> exp_on_negative_values(
464     FixedPoint<tRawType, tIntegerBits> a) {
465   typedef FixedPoint<tRawType, tIntegerBits> InputF;
466   typedef FixedPoint<tRawType, 0> ResultF;
467   static const int kFractionalBits = InputF::kFractionalBits;
468   static const int kIntegerBits = InputF::kIntegerBits;
469   static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
470   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
471   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
472   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
473       Rescale<0>(a_mod_quarter_minus_one_quarter));
474   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
475 
476 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
477   if (kIntegerBits > Exponent) {                                            \
478     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
479         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
480     result = SelectUsingMask(                                               \
481         MaskIfNonZero(BitAnd(                                               \
482             remainder, Dup<tRawType>(1 << (kFractionalBits + Exponent)))),  \
483         result * kMultiplier, result);                                      \
484   }
485 
486   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
487   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
488   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
489   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
490   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
491   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
492   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
493 
494 #undef GEMMLOWP_EXP_BARREL_SHIFTER
495 
496   if (kIntegerBits > 5) {
497     static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
498     const InputF clamp =
499         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
500     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
501   }
502 
503   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
504   return result;
505 }
506 
507 template <typename tRawType>
508 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
509     FixedPoint<tRawType, 0> a) {
510   typedef FixedPoint<tRawType, 0> F0;
511   typedef FixedPoint<tRawType, 2> F2;
512   F0 half_denominator = RoundingHalfSum(a, F0::One());
513   const F2 constant_48_over_17 =
514       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
515   const F2 constant_neg_32_over_17 =
516       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
517   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
518   for (int i = 0; i < 3; i++) {
519     F2 half_denominator_times_x = half_denominator * x;
520     F2 one_minus_half_denominator_times_x =
521         F2::One() - half_denominator_times_x;
522     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
523   }
524   return Rescale<0>(x - F2::One());
525 }
526 
527 template <typename tRawType, int tIntegerBits>
528 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
529     FixedPoint<tRawType, tIntegerBits> a) {
530   return one_minus_x_over_one_plus_x_for_x_in_0_1(
531       exp_on_negative_values(ExactMulByPot<1>(a)));
532 }
533 
534 template <typename tRawType, int tIntegerBits>
535 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
536   typedef FixedPoint<tRawType, tIntegerBits> InputF;
537   typedef FixedPoint<tRawType, 0> ResultF;
538   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
539   tRawType mask_if_zero = MaskIfZero(a);
540   InputF n = SelectUsingMask(mask_if_negative, a, -a);
541   ResultF t = neg_tanh_on_negative_values(n);
542   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
543                          SelectUsingMask(mask_if_negative, -t, t));
544 }
545 
546 }  // end namespace gemmlowp
547 
548 #ifdef GEMMLOWP_NEON
549 #include "fixedpoint_neon.h"
550 #endif
551 
552 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
553