1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
18 
#include <complex>
#include <cstdint>
#include <cstring>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Disable clang-format to prevent 'FixedPoint' header from being included
// before 'Tensor' header on which it depends.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on

#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 // Single precision complex.
33 typedef std::complex<float> complex64;
34 // Double precision complex.
35 typedef std::complex<double> complex128;
36 
37 // We use Eigen's QInt implementations for our quantized int types.
38 typedef Eigen::QInt8 qint8;
39 typedef Eigen::QUInt8 quint8;
40 typedef Eigen::QInt32 qint32;
41 typedef Eigen::QInt16 qint16;
42 typedef Eigen::QUInt16 quint16;
43 
44 }  // namespace tensorflow
45 
46 
47 
48 
FloatToBFloat16(float float_val)49 static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
50 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
51     return *reinterpret_cast<tensorflow::bfloat16*>(
52         reinterpret_cast<uint16_t*>(&float_val));
53 #else
54     return *reinterpret_cast<tensorflow::bfloat16*>(
55         &(reinterpret_cast<uint16_t*>(&float_val)[1]));
56 #endif
57 }
58 
59 namespace Eigen {
60 // TODO(xpan): We probably need to overwrite more methods to have correct eigen
61 // behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
62 template <>
63 struct NumTraits<tensorflow::bfloat16>
64     : GenericNumTraits<tensorflow::bfloat16> {
65   enum {
66     IsInteger = 0,
67     IsSigned = 1,
68     RequireInitialization = 0
69   };
70   static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
71     return FloatToBFloat16(NumTraits<float>::highest());
72   }
73 
74   static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
75     return FloatToBFloat16(NumTraits<float>::lowest());
76   }
77 
78   static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
79     return FloatToBFloat16(NumTraits<float>::infinity());
80   }
81 
82   static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
83     return FloatToBFloat16(NumTraits<float>::quiet_NaN());
84   }
85 };
86 
87 
88 using ::tensorflow::operator==;
89 using ::tensorflow::operator!=;
90 
91 namespace numext {
92 
93 template <>
94 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
95     const tensorflow::bfloat16& x) {
96   return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
97 }
98 
99 template <>
100 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
101     const tensorflow::bfloat16& x) {
102   return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
103 }
104 
105 template <>
106 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
107     const tensorflow::bfloat16& x) {
108   return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
109 }
110 
111 }  // namespace numext
112 }  // namespace Eigen
113 
#if defined(_MSC_VER) && !defined(__clang__)
namespace std {
// std::hash for Eigen::half; this specialization is only compiled for MSVC
// (non-clang) builds.
template <>
struct hash<Eigen::half> {
  std::size_t operator()(const Eigen::half& a) const {
    // Hash the raw 16-bit pattern, except that both zeros hash to 0. +0.0
    // (0x0000) and -0.0 (0x8000) are equal under IEEE-754 comparison, and
    // std::hash must give equal keys equal hashes; otherwise unordered
    // containers would store the two zeros as distinct keys.
    const std::size_t bits = static_cast<std::size_t>(a.x);
    return (bits & 0x7fff) == 0 ? 0 : bits;
  }
};
}  // namespace std
#endif  // _MSC_VER
124 
125 #endif  // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
126