/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

// Note that the GPU cast functor templates need to be explicitly
// instantiated, unlike the CPU ones, so their specializations differ from the
// CPU specializations.
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <typename Device>                                        \
  struct CastFunctor<Device, OUT_TYPE, IN_OUT> {                    \
    void operator()(const Device& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };                                                                \
  template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <>                                                       \
  struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> {                    \
    void operator()(const DEVICE& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };
#endif

#define CAST_FUNCTORS(devname)                                        \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, double)                       \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
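
// Illustrative sketch only (the actual instantiation lives in the Cast
// kernel's .cc files, not in this header): a CPU translation unit would
// typically expand these functors for its device type with something like
//
//   namespace tensorflow {
//   namespace functor {
//   CAST_FUNCTORS(Eigen::ThreadPoolDevice);
//   }  // namespace functor
//   }  // namespace tensorflow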

namespace tensorflow {

typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
                           bool trunc)>
    CastFunctorType;

// Common base class of Cast kernels
class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  DataType external_src_dtype_;
  DataType external_dst_dtype_;
  bool use_truncation_;
  CastFunctorType work_ = nullptr;
  Status Unimplemented();

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};

// CPU implementation of Cast
class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx);

 private:
  Status Prepare();
};

namespace functor {

template <typename I>
constexpr int MantissaWidth() {
  return std::numeric_limits<I>::digits;
}

template <>
constexpr int MantissaWidth<Eigen::half>() {
  // Remember, there's 1 hidden bit
  return 10 + 1;
}

template <>
constexpr int MantissaWidth<bfloat16>() {
  // Remember, there's 1 hidden bit
  return 7 + 1;
}
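
// Worked example (illustrative only): std::numeric_limits<T>::digits counts
// the hidden bit, so MantissaWidth<float>() == 24, MantissaWidth<double>()
// == 53, MantissaWidth<Eigen::half>() == 11 and MantissaWidth<bfloat16>()
// == 8. A truncating cast therefore clears
//   24 - 8  == 16 low mantissa bits for float  -> bfloat16,
//   24 - 11 == 13 low mantissa bits for float  -> Eigen::half,
//   53 - 24 == 29 low mantissa bits for double -> float.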

template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
          typename TTypes<Tin>::ConstFlat i) {
  o.device(d) = i.template cast<Tout>();
}
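
// Illustrative usage sketch (the tensor and device names here are
// assumptions, not part of this header): given an Eigen device `d`, a const
// input Tensor `in` of int32 and an output Tensor `out` of float, the
// non-truncating path boils down to
//
//   functor::Cast<Eigen::ThreadPoolDevice, float, int32>(
//       d, out.flat<float>(), in.flat<int32>());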

template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i, bool truncate = false);
};

// Only enable LSBZeroSetterHelper for 64- and 32-bit input data types.
// Specialize for others if needed in the future.
template <typename I>
typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint64_t* p = reinterpret_cast<uint64_t*>(&t);
    *p &= (0xFFFFFFFFFFFFFFFF << n);
  }
}

template <typename I>
typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint32_t* p = reinterpret_cast<uint32_t*>(&t);
    *p &= (0xFFFFFFFF << n);
  }
}
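
// Worked example (illustrative only): for a float input and bfloat16 output,
// n == 16, so the mask is (0xFFFFFFFF << 16) == 0xFFFF0000 and the low 16
// mantissa bits are cleared before the conversion:
//
//   float f = 1.99999988f;       // bit pattern 0x3FFFFFFF
//   LSBZeroSetterHelper(f, 16);  // bit pattern 0x3FFF0000, i.e. 1.9921875f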

// Set n least significant bits to 0
template <typename I, typename O>
struct LSBZeroSetter {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I t = a;
    LSBZeroSetterHelper(t, bits);
    return t;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, O> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
  // Zeroes the low mantissa bits of both components before the cast to a
  // real output type.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

}  // end namespace functor
}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Eigen can't convert to/from complex numbers because it is limited to cases
// that can be handled by static_cast. But numpy is able to cast to/from
// complex, which we want to replicate. So we add specializations for complex
// here.
template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, To> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
  operator()(const std::complex<From>& a) const {
    // Replicate numpy behavior of returning just the real part
    return static_cast<To>(a.real());
  }
};

template <typename From, typename To>
struct scalar_cast_op<From, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const From& a) const {
    // Replicate numpy behavior of setting the imaginary part to 0
    return std::complex<To>(static_cast<To>(a), To(0));
  }
};

template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const std::complex<From>& a) const {
    return std::complex<To>(static_cast<To>(a.real()),
                            static_cast<To>(a.imag()));
  }
};
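
// Illustrative sketch only: with these specializations an Eigen cast mirrors
// numpy's complex casting rules, e.g.
//
//   std::complex<float> c(2.5f, -1.0f);
//   float r = scalar_cast_op<std::complex<float>, float>()(c);  // r == 2.5f
//   std::complex<double> z =
//       scalar_cast_op<float, std::complex<double>>()(3.0f);    // z == (3, 0)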

template <typename From, typename To>
struct functor_traits_complex_impl {
  enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
};

template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, To>>
    : functor_traits_complex_impl<std::complex<From>, To> {};
template <typename From, typename To>
struct functor_traits<scalar_cast_op<From, std::complex<To>>>
    : functor_traits_complex_impl<From, std::complex<To>> {};
// Needed to avoid ambiguous partial specialization
template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
    : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_H_