1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
18
#include <cstring>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
26
// SPECIALIZE_CAST(DEVICE, OUT_T, IN_T) defines the CastFunctor that casts
// IN_T tensors into OUT_T tensors on DEVICE. If the call operator's
// `truncate` flag is set, each input element first has its low mantissa bits
// (those OUT_T cannot hold) zeroed by LSBZeroSetter and is then cast;
// otherwise a plain Eigen cast is performed.
//
// GPU cast functor templates need to be explicitly instantiated, unlike the
// CPU ones, so the two variants differ:
//  * with SPECIALIZE_FOR_GPUS, the macro declares a partial specialization
//    over an arbitrary Device and explicitly instantiates it for DEVICE;
//  * otherwise, it declares a full specialization for DEVICE.
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_T, IN_T)                                  \
  template <typename Device>                                                  \
  struct CastFunctor<Device, OUT_T, IN_T> {                                   \
    void operator()(const Device& d, typename TTypes<OUT_T>::Flat out_tensor, \
                    typename TTypes<IN_T>::ConstFlat in_tensor,               \
                    bool truncate = false) {                                  \
      if (!truncate) {                                                        \
        out_tensor.device(d) = in_tensor.template cast<OUT_T>();              \
        return;                                                               \
      }                                                                       \
      out_tensor.device(d) =                                                  \
          in_tensor.unaryExpr(LSBZeroSetter<IN_T, OUT_T>())                   \
              .template cast<OUT_T>();                                        \
    }                                                                         \
  };                                                                          \
  template struct CastFunctor<DEVICE, OUT_T, IN_T>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_T, IN_T)                                  \
  template <>                                                                 \
  struct CastFunctor<DEVICE, OUT_T, IN_T> {                                   \
    void operator()(const DEVICE& d, typename TTypes<OUT_T>::Flat out_tensor, \
                    typename TTypes<IN_T>::ConstFlat in_tensor,               \
                    bool truncate = false) {                                  \
      if (!truncate) {                                                        \
        out_tensor.device(d) = in_tensor.template cast<OUT_T>();              \
        return;                                                               \
      }                                                                       \
      out_tensor.device(d) =                                                  \
          in_tensor.unaryExpr(LSBZeroSetter<IN_T, OUT_T>())                   \
              .template cast<OUT_T>();                                        \
    }                                                                         \
  };
#endif
65
// CAST_FUNCTORS(DEVICE) defines every CastFunctor for DEVICE.
//  * The (output, input) pairs listed through SPECIALIZE_CAST honor the
//    `truncate` flag.
//  * All other pairs fall back to the catch-all partial specialization at the
//    end. It always performs a plain cast and ignores `truncate`.
//
// NOTE: keep the SPECIALIZE_CAST lines before the catch-all. In the GPU build,
// each SPECIALIZE_CAST explicitly instantiates a partial specialization over
// an arbitrary device. If the catch-all (fixed device, arbitrary types) were
// already declared at that point, both partial specializations would match
// and the instantiation would be ambiguous.
#define CAST_FUNCTORS(DEVICE)                                          \
  SPECIALIZE_CAST(DEVICE, float, double)                               \
  SPECIALIZE_CAST(DEVICE, float, std::complex<double>)                 \
  SPECIALIZE_CAST(DEVICE, std::complex<float>, std::complex<double>)   \
  SPECIALIZE_CAST(DEVICE, std::complex<float>, double)                 \
  SPECIALIZE_CAST(DEVICE, Eigen::half, double)                         \
  SPECIALIZE_CAST(DEVICE, Eigen::half, float)                          \
  SPECIALIZE_CAST(DEVICE, Eigen::half, std::complex<double>)           \
  SPECIALIZE_CAST(DEVICE, Eigen::half, std::complex<float>)            \
  SPECIALIZE_CAST(DEVICE, bfloat16, float)                             \
  template <typename OUT_T, typename IN_T>                             \
  struct CastFunctor<DEVICE, OUT_T, IN_T> {                            \
    void operator()(const DEVICE& d,                                   \
                    typename TTypes<OUT_T>::Flat out_tensor,           \
                    typename TTypes<IN_T>::ConstFlat in_tensor,        \
                    bool truncate = false) {                           \
      out_tensor.device(d) = in_tensor.template cast<OUT_T>();         \
    }                                                                  \
  };
85
86 namespace tensorflow {
87
88 typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
89 bool trunc)>
90 CastFunctorType;
91
92 // Common base class of Cast kernels
93 class CastOpBase : public OpKernel {
94 public:
95 explicit CastOpBase(OpKernelConstruction* ctx);
96
97 void Compute(OpKernelContext* ctx) override;
98
99 protected:
100 DataType src_dtype_;
101 DataType dst_dtype_;
102 DataType external_src_dtype_;
103 DataType external_dst_dtype_;
104 bool use_truncation_;
105 CastFunctorType work_ = nullptr;
106 Status Unimplemented();
107
108 TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
109 };
110
111 // CPU implementation of Cast
112 class CpuCastOp : public CastOpBase {
113 public:
114 explicit CpuCastOp(OpKernelConstruction* ctx);
115
116 private:
117 Status Prepare();
118 };
119
120 namespace functor {
121
// Number of significand bits of floating-point type I, counting the implicit
// leading bit. This is exactly std::numeric_limits<I>::digits, e.g. 24 for
// float and 53 for double.
template <typename I>
constexpr int MantissaWidth() {
  typedef std::numeric_limits<I> Limits;
  return Limits::digits;
}
126
127 template <>
128 constexpr int MantissaWidth<Eigen::half>() {
129 // Remember, there's 1 hidden bit
130 return 10 + 1;
131 }
132
133 template <>
134 constexpr int MantissaWidth<bfloat16>() {
135 // Remember, there's 1 hidden bit
136 return 7 + 1;
137 }
138
139 template <typename Device, typename Tout, typename Tin>
Cast(const Device & d,typename TTypes<Tout>::Flat o,typename TTypes<Tin>::ConstFlat i)140 void Cast(const Device& d, typename TTypes<Tout>::Flat o,
141 typename TTypes<Tin>::ConstFlat i) {
142 o.device(d) = i.template cast<Tout>();
143 }
144
145 template <typename Device, typename Tout, typename Tin>
146 struct CastFunctor {
147 void operator()(const Device& d, typename TTypes<Tout>::Flat o,
148 typename TTypes<Tin>::ConstFlat i, bool truncate = false);
149 };
150
151 // Only enable LSBZeroSetterHelper for 64 and 32 bit input data types.
152 // Specialize for others if needed in future.
153 template <typename I>
154 typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
LSBZeroSetterHelper(I & t,int n)155 EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
156 // Only zero the bits for non-NaNs.
157 // For NaNs, let the non-truncation version handle it.
158 if (!std::isnan(t)) {
159 uint64_t* p = reinterpret_cast<uint64_t*>(&t);
160 *p &= (0xFFFFFFFFFFFFFFFF << n);
161 }
162 }
163
164 template <typename I>
165 typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
LSBZeroSetterHelper(I & t,int n)166 EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
167 // Only zero the bits for non-NaNs.
168 // For NaNs, let the non-truncation version handle it.
169 if (!std::isnan(t)) {
170 uint32_t* p = reinterpret_cast<uint32_t*>(&t);
171 *p &= (0xFFFFFFFF << n);
172 }
173 }
174
175 // Set n least significant bits to 0
176 template <typename I, typename O>
177 struct LSBZeroSetter {
EIGEN_EMPTY_STRUCT_CTORLSBZeroSetter178 EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
179
180 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
181 constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
182 static_assert(
183 bits > 0,
184 "The output type must have fewer mantissa bits than the input type\n");
185 I t = a;
186 LSBZeroSetterHelper(t, bits);
187 return t;
188 }
189 };
190
191 template <typename I, typename O>
192 struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
193 EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
194
195 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
196 const std::complex<I>& a) const {
197 constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
198 static_assert(
199 bits > 0,
200 "The output type must have fewer mantissa bits than the input type\n");
201 I re = std::real(a);
202 I img = std::imag(a);
203 LSBZeroSetterHelper(re, bits);
204 LSBZeroSetterHelper(img, bits);
205 std::complex<I> toReturn(re, img);
206 return toReturn;
207 }
208 };
209
210 template <typename I, typename O>
211 struct LSBZeroSetter<std::complex<I>, O> {
212 EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
213 // Sets the 16 LSBits of the float to 0
214 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
215 const std::complex<I>& a) const {
216 constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
217 static_assert(
218 bits > 0,
219 "The output type must have fewer mantissa bits than the input type\n");
220 I re = std::real(a);
221 I img = std::imag(a);
222 LSBZeroSetterHelper(re, bits);
223 LSBZeroSetterHelper(img, bits);
224 std::complex<I> toReturn(re, img);
225 return toReturn;
226 }
227 };
228
229 } // end namespace functor
230 } // end namespace tensorflow
231
232 namespace Eigen {
233 namespace internal {
234
// Eigen cannot convert to/from complex numbers, because its scalar_cast_op only
// handles conversions that can be expressed as a static_cast. numpy, however,
// can cast to/from complex, and we want to replicate that behavior, so we add
// specializations for complex types here.
238 template <typename From, typename To>
239 struct scalar_cast_op<std::complex<From>, To> {
240 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
241 operator()(const std::complex<From>& a) const {
242 // Replicate numpy behavior of returning just the real part
243 return static_cast<To>(a.real());
244 }
245 };
246
247 template <typename From, typename To>
248 struct scalar_cast_op<From, std::complex<To>> {
249 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
250 const From& a) const {
251 // Replicate numpy behavior of setting the imaginary part to 0
252 return std::complex<To>(static_cast<To>(a), To(0));
253 }
254 };
255
256 template <typename From, typename To>
257 struct scalar_cast_op<std::complex<From>, std::complex<To>> {
258 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
259 const std::complex<From>& a) const {
260 return std::complex<To>(static_cast<To>(a.real()),
261 static_cast<To>(a.imag()));
262 }
263 };
264
265 template <typename From, typename To>
266 struct functor_traits_complex_impl {
267 enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
268 };
269
270 template <typename From, typename To>
271 struct functor_traits<scalar_cast_op<std::complex<From>, To>>
272 : functor_traits_complex_impl<std::complex<From>, To> {};
273 template <typename From, typename To>
274 struct functor_traits<scalar_cast_op<From, std::complex<To>>>
275 : functor_traits_complex_impl<From, std::complex<To>> {};
276 // Needed to avoid ambiguous partial specialization
277 template <typename From, typename To>
278 struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
279 : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
280
281 } // namespace internal
282 } // namespace Eigen
283
284 #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_
285