1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
17 #define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
18 
19 #include <cmath>
20 #include <complex>
21 
22 #include "tensorflow/core/platform/byte_order.h"
23 
24 #ifdef __CUDACC__
25 // All functions callable from CUDA code must be qualified with __device__
26 #define B16_DEVICE_FUNC __host__ __device__
27 
28 #else
29 #define B16_DEVICE_FUNC
30 
31 #endif
32 
// Forward declaration only: bfloat16 declares (but does not define here) an
// explicit conversion operator to Eigen::half, so the complete type is not
// needed in this header.
namespace Eigen {
struct half;
}
36 
37 namespace tensorflow {
38 
// Complex number whose real and imaginary parts are single precision floats.
using complex64 = std::complex<float>;
// Complex number whose real and imaginary parts are double precision floats.
using complex128 = std::complex<double>;
43 
44 // see framework/bfloat16.h for description.
45 struct bfloat16 {
46   // The default constructor must yield a zero value, not an uninitialized
47   // value; some TF kernels use T() as a zero value.
bfloat16bfloat1648   B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
49 
truncate_to_bfloat16bfloat1650   B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
51     bfloat16 output;
52     if (float_isnan(v)) {
53       output.value = NAN_VALUE;
54       return output;
55     }
56     const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
57 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
58     output.value = p[0];
59 #else
60     output.value = p[1];
61 #endif
62     return output;
63   }
64 
bfloat16bfloat1665   B16_DEVICE_FUNC explicit bfloat16(const float v) {
66     value = round_to_bfloat16(v).value;
67   }
68 
bfloat16bfloat1669   B16_DEVICE_FUNC explicit bfloat16(const double val)
70       : bfloat16(static_cast<float>(val)) {}
71   // Following the convention of numpy, converting between complex and
72   // float will lead to loss of imag value.
bfloat16bfloat1673   B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
74       : bfloat16(val.real()) {}
75 
bfloat16bfloat1676   B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
77       : bfloat16(static_cast<float>(val.real())) {}
78 
bfloat16bfloat1679   B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
80       : bfloat16(static_cast<float>(val)) {}
81 
bfloat16bfloat1682   B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
83       : bfloat16(static_cast<float>(val)) {}
84 
bfloat16bfloat1685   B16_DEVICE_FUNC explicit bfloat16(const int val)
86       : bfloat16(static_cast<float>(val)) {}
87 
bfloat16bfloat1688   B16_DEVICE_FUNC explicit bfloat16(const long val)
89       : bfloat16(static_cast<float>(val)) {}
90 
bfloat16bfloat1691   B16_DEVICE_FUNC explicit bfloat16(const long long val)
92       : bfloat16(static_cast<float>(val)) {}
93 
94   template <class T>
bfloat16bfloat1695   B16_DEVICE_FUNC explicit bfloat16(const T& val)
96       : bfloat16(static_cast<float>(val)) {}
97 
98   B16_DEVICE_FUNC explicit operator float() const {
99     float result = 0;
100 
101     uint16_t* q = reinterpret_cast<uint16_t*>(&result);
102 
103 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
104     q[0] = value;
105 #else
106     q[1] = value;
107 #endif
108     return result;
109   }
110 
111   B16_DEVICE_FUNC explicit operator bool() const {
112     return static_cast<bool>(float(*this));
113   }
114 
115   B16_DEVICE_FUNC explicit operator Eigen::half() const;
116 
117   B16_DEVICE_FUNC explicit operator short() const {
118     return static_cast<short>(float(*this));
119   }
120 
121   B16_DEVICE_FUNC explicit operator int() const {
122     return static_cast<int>(float(*this));
123   }
124 
125   B16_DEVICE_FUNC explicit operator long() const {
126     return static_cast<long>(float(*this));
127   }
128 
129   B16_DEVICE_FUNC explicit operator char() const {
130     return static_cast<char>(float(*this));
131   }
132 
133   B16_DEVICE_FUNC explicit operator signed char() const {
134     return static_cast<signed char>(float(*this));
135   }
136 
137   B16_DEVICE_FUNC explicit operator unsigned char() const {
138     return static_cast<unsigned char>(float(*this));
139   }
140 
141   B16_DEVICE_FUNC explicit operator unsigned short() const {
142     return static_cast<unsigned short>(float(*this));
143   }
144 
145   B16_DEVICE_FUNC explicit operator unsigned int() const {
146     return static_cast<unsigned int>(float(*this));
147   }
148 
149   B16_DEVICE_FUNC explicit operator unsigned long() const {
150     return static_cast<unsigned long>(float(*this));
151   }
152 
153   B16_DEVICE_FUNC explicit operator unsigned long long() const {
154     return static_cast<unsigned long long>(float(*this));
155   }
156 
157   B16_DEVICE_FUNC explicit operator long long() const {
158     return static_cast<long long>(float(*this));
159   }
160 
161   B16_DEVICE_FUNC explicit operator double() const {
162     return static_cast<double>(float(*this));
163   }
164 
complex64bfloat16165   B16_DEVICE_FUNC explicit operator complex64() const {
166     return complex64(float(*this), float(0.0));
167   }
168 
complex128bfloat16169   B16_DEVICE_FUNC explicit operator complex128() const {
170     return complex128(double(*this), double(0.0));
171   }
172 
173   union FP32 {
174     unsigned int u;
175     float f;
176   };
177 
178   // Converts a float point to bfloat16, with round-nearest-to-even as rounding
179   // method.
180   // TODO: There is a slightly faster implementation (8% faster on CPU)
181   // than this (documented in cl/175987786), that is exponentially harder to
182   // understand and document. Switch to the faster version when converting to
183   // BF16 becomes compute-bound.
round_to_bfloat16bfloat16184   B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
185     uint32_t input;
186     FP32 f;
187     f.f = v;
188     input = f.u;
189     bfloat16 output;
190 
191     if (float_isnan(v)) {
192       // If the value is a NaN, squash it to a qNaN with msb of fraction set,
193       // this makes sure after truncation we don't end up with an inf.
194       //
195       // qNaN magic: All exponent bits set + most significant bit of fraction
196       // set.
197       output.value = 0x7fc0;
198     } else {
199       // Fast rounding algorithm that rounds a half value to nearest even. This
200       // reduces expected error when we convert a large number of floats. Here
201       // is how it works:
202       //
203       // Definitions:
204       // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
205       // with the following tags:
206       //
207       // Sign |  Exp (8 bits) | Frac (23 bits)
208       //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
209       //
210       //  S: Sign bit.
211       //  E: Exponent bits.
212       //  F: First 6 bits of fraction.
213       //  L: Least significant bit of resulting bfloat16 if we truncate away the
214       //  rest of the float32. This is also the 7th bit of fraction
215       //  R: Rounding bit, 8th bit of fraction.
216       //  T: Sticky bits, rest of fraction, 15 bits.
217       //
218       // To round half to nearest even, there are 3 cases where we want to round
219       // down (simply truncate the result of the bits away, which consists of
220       // rounding bit and sticky bits) and two cases where we want to round up
221       // (truncate then add one to the result).
222       //
223       // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
224       // 1s) as the rounding bias, adds the rounding bias to the input, then
225       // truncates the last 16 bits away.
226       //
227       // To understand how it works, we can analyze this algorithm case by case:
228       //
229       // 1. L = 0, R = 0:
230       //   Expect: round down, this is less than half value.
231       //
232       //   Algorithm:
233       //   - Rounding bias: 0x7fff + 0 = 0x7fff
234       //   - Adding rounding bias to input may create any carry, depending on
235       //   whether there is any value set to 1 in T bits.
236       //   - R may be set to 1 if there is a carry.
237       //   - L remains 0.
238       //   - Note that this case also handles Inf and -Inf, where all fraction
239       //   bits, including L, R and Ts are all 0. The output remains Inf after
240       //   this algorithm.
241       //
242       // 2. L = 1, R = 0:
243       //   Expect: round down, this is less than half value.
244       //
245       //   Algorithm:
246       //   - Rounding bias: 0x7fff + 1 = 0x8000
247       //   - Adding rounding bias to input doesn't change sticky bits but
248       //   adds 1 to rounding bit.
249       //   - L remains 1.
250       //
251       // 3. L = 0, R = 1, all of T are 0:
252       //   Expect: round down, this is exactly at half, the result is already
253       //   even (L=0).
254       //
255       //   Algorithm:
256       //   - Rounding bias: 0x7fff + 0 = 0x7fff
257       //   - Adding rounding bias to input sets all sticky bits to 1, but
258       //   doesn't create a carry.
259       //   - R remains 1.
260       //   - L remains 0.
261       //
262       // 4. L = 1, R = 1:
263       //   Expect: round up, this is exactly at half, the result needs to be
264       //   round to the next even number.
265       //
266       //   Algorithm:
267       //   - Rounding bias: 0x7fff + 1 = 0x8000
268       //   - Adding rounding bias to input doesn't change sticky bits, but
269       //   creates a carry from rounding bit.
270       //   - The carry sets L to 0, creates another carry bit and propagate
271       //   forward to F bits.
272       //   - If all the F bits are 1, a carry then propagates to the exponent
273       //   bits, which then creates the minimum value with the next exponent
274       //   value. Note that we won't have the case where exponents are all 1,
275       //   since that's either a NaN (handled in the other if condition) or inf
276       //   (handled in case 1).
277       //
278       // 5. L = 0, R = 1, any of T is 1:
279       //   Expect: round up, this is greater than half.
280       //
281       //   Algorithm:
282       //   - Rounding bias: 0x7fff + 0 = 0x7fff
283       //   - Adding rounding bias to input creates a carry from sticky bits,
284       //   sets rounding bit to 0, then create another carry.
285       //   - The second carry sets L to 1.
286       //
287       // Examples:
288       //
289       //  Exact half value that is already even:
290       //    Input:
291       //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
292       //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
293       //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
294       //
295       //     This falls into case 3. We truncate the rest of 16 bits and no
296       //     carry is created into F and L:
297       //
298       //    Output:
299       //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
300       //     S     E E E E E E E E      F F F F F F L
301       //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
302       //
303       //  Exact half value, round to next even number:
304       //    Input:
305       //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
306       //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
307       //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
308       //
309       //     This falls into case 4. We create a carry from R and T,
310       //     which then propagates into L and F:
311       //
312       //    Output:
313       //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
314       //     S     E E E E E E E E      F F F F F F L
315       //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
316       //
317       //
318       //  Max denormal value round to min normal value:
319       //    Input:
320       //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
321       //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
322       //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
323       //
324       //     This falls into case 4. We create a carry from R and T,
325       //     propagate into L and F, which then propagates into exponent
326       //     bits:
327       //
328       //    Output:
329       //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
330       //     S     E E E E E E E E      F F F F F F L
331       //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
332       //
333       //  Max normal value round to Inf:
334       //    Input:
335       //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
336       //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
337       //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
338       //
339       //     This falls into case 4. We create a carry from R and T,
340       //     propagate into L and F, which then propagates into exponent
341       //     bits:
342       //
343       //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
344       //     S     E E E E E E E E      F F F F F F L
345       //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0
346       //
347       //
348       // Least significant bit of resulting bfloat.
349       uint32_t lsb = (input >> 16) & 1;
350       uint32_t rounding_bias = 0x7fff + lsb;
351       input += rounding_bias;
352       output.value = static_cast<uint16_t>(input >> 16);
353     }
354     return output;
355   }
356 
epsilonbfloat16357   static bfloat16 epsilon() {
358     bfloat16 x;
359     x.value = 0x3c00;  // 0x1.0p-7
360     return x;
361   }
362 
highestbfloat16363   static bfloat16 highest() {
364     bfloat16 x;
365     x.value = 0x7F7F;  // 0x1.FEp127
366     return x;
367   }
368 
lowestbfloat16369   static bfloat16 lowest() {
370     bfloat16 x;
371     x.value = 0xFF7F;  // -0x1.FEp127
372     return x;
373   }
374 
min_positive_normalbfloat16375   static bfloat16 min_positive_normal() {
376     bfloat16 x;
377     x.value = 0x0080;  // 0x1p-126
378     return x;
379   }
380 
IsZerobfloat16381   bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; }
382 
383   uint16_t value;
384 
385   // A value that represents "not a number".
386   static const uint16_t NAN_VALUE = 0x7FC0;
387 
388  private:
389   // A value that represents "zero".
390   static const uint16_t ZERO_VALUE = 0;
391 
float_isnanbfloat16392   B16_DEVICE_FUNC static bool float_isnan(const float& x) {
393 #ifdef __CUDA_ARCH__
394     return ::isnan(x);
395 #else
396     return std::isnan(x);
397 #endif
398   }
399 };
400 
401 B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
402                                                 const bfloat16& dt) {
403   os << static_cast<float>(dt);
404   return os;
405 }
406 
407 B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
408   return bfloat16(static_cast<float>(a) + static_cast<float>(b));
409 }
410 B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
411   return bfloat16(static_cast<float>(a) + static_cast<float>(b));
412 }
413 B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
414   return bfloat16(static_cast<float>(a) + static_cast<float>(b));
415 }
416 B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
417   return bfloat16(static_cast<float>(a) - static_cast<float>(b));
418 }
419 B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
420   return bfloat16(static_cast<float>(a) * static_cast<float>(b));
421 }
422 B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
423   return bfloat16(static_cast<float>(a) / static_cast<float>(b));
424 }
425 B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
426   a.value ^= 0x8000;
427   return a;
428 }
429 B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
430   return static_cast<float>(a) < static_cast<float>(b);
431 }
432 B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
433   return static_cast<float>(a) <= static_cast<float>(b);
434 }
435 B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
436   return static_cast<float>(a) == static_cast<float>(b);
437 }
438 B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
439   return static_cast<float>(a) != static_cast<float>(b);
440 }
441 B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
442   return static_cast<float>(a) > static_cast<float>(b);
443 }
444 B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
445   return static_cast<float>(a) >= static_cast<float>(b);
446 }
447 B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
448   a = a + b;
449   return a;
450 }
451 B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
452   a = a - b;
453   return a;
454 }
455 B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
456   a += bfloat16(1);
457   return a;
458 }
459 B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
460   a -= bfloat16(1);
461   return a;
462 }
463 B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
464   bfloat16 original_value = a;
465   ++a;
466   return original_value;
467 }
468 B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
469   bfloat16 original_value = a;
470   --a;
471   return original_value;
472 }
473 B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
474   a = a * b;
475   return a;
476 }
477 B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
478   a = a / b;
479   return a;
480 }
481 }  // end namespace tensorflow
482 
483 namespace std {
484 template <>
485 struct hash<tensorflow::bfloat16> {
486   size_t operator()(const tensorflow::bfloat16& v) const {
487     return hash<float>()(static_cast<float>(v));
488   }
489 };
490 
491 using tensorflow::bfloat16;
492 inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
493 inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); }
494 inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); }
495 inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); }
496 inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); }
497 inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); }
498 inline bfloat16 log10(const bfloat16& a) {
499   return bfloat16(std::log10(float(a)));
500 }
501 inline bfloat16 sqrt(const bfloat16& a) {
502   return bfloat16(std::sqrt(float(a)));
503 }
504 inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
505   return bfloat16(std::pow(float(a), float(b)));
506 }
507 inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); }
508 inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); }
509 inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); }
510 inline bfloat16 tanh(const bfloat16& a) {
511   return bfloat16(std::tanh(float(a)));
512 }
513 inline bfloat16 floor(const bfloat16& a) {
514   return bfloat16(std::floor(float(a)));
515 }
516 inline bfloat16 ceil(const bfloat16& a) {
517   return bfloat16(std::ceil(float(a)));
518 }
519 }  // namespace std
520 
521 #endif  // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
522