1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
18 
19 #define EIGEN_USE_THREADS
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/kernels/cast_op.h"
23 
24 namespace tensorflow {
25 
namespace functor {

// Expands CAST_FUNCTORS for the CPU thread-pool device
// (Eigen::ThreadPoolDevice). CAST_FUNCTORS is not defined in this file —
// presumably it comes from cast_op.h and supplies the
// functor::CastFunctor<Eigen::ThreadPoolDevice, OUT, IN> code that the CPU
// CAST_CASE lambdas invoke; TODO confirm against cast_op.h.
CAST_FUNCTORS(Eigen::ThreadPoolDevice);


}  // namespace functor
32 
// Type-list macros: each expands to one `FN(arg0, arg1, T)` statement per
// supported element type T. With FN = CAST_CASE, arg0 would be the device and
// arg1 the source type, matching CAST_CASE(DEVICE, IN, OUT).
//
// CURRY_TYPES3_NO_HALF covers every type except Eigen::half and bfloat16 (13
// types). Its last expansion intentionally has NO trailing semicolon, so a
// caller must terminate it: `CURRY_TYPES3_NO_HALF(FN, a, b);`.
#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \
  FN(arg0, arg1, bool);                      \
  FN(arg0, arg1, uint8);                     \
  FN(arg0, arg1, uint16);                    \
  FN(arg0, arg1, uint32);                    \
  FN(arg0, arg1, uint64);                    \
  FN(arg0, arg1, int8);                      \
  FN(arg0, arg1, int16);                     \
  FN(arg0, arg1, int32);                     \
  FN(arg0, arg1, int64);                     \
  FN(arg0, arg1, float);                     \
  FN(arg0, arg1, double);                    \
  FN(arg0, arg1, std::complex<float>);       \
  FN(arg0, arg1, std::complex<double>)

// CURRY_TYPES3_NO_BF16 adds Eigen::half (14 types). The `;` after the nested
// CURRY_TYPES3_NO_HALF is required. Without it, the std::complex<double>
// expansion runs straight into the Eigen::half one, which is a syntax error
// for any FN that expands to an expression or declaration.
// It ends with its own `;`, so a trailing `;` at the call site is harmless.
#define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  CURRY_TYPES3_NO_HALF(FN, arg0, arg1);      \
  FN(arg0, arg1, Eigen::half);

// CURRY_TYPES3 covers the full list, including bfloat16 (15 types).
#define CURRY_TYPES3(FN, arg0, arg1)   \
  CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  FN(arg0, arg1, bfloat16);
55 
// One arm of a destination-dtype dispatch. Must be expanded inside a function
// that has a `DataType dst_dtype` in scope and returns something a captureless
// lambda converts to (presumably the GetCpuCastFrom*/GetGpuCastFrom*
// factories, returning CastFunctorType — confirm).
// When dst_dtype names OUT, the function returns a lambda that casts the flat
// IN view of `input` into the flat OUT view of `output` using
// functor::CastFunctor<DEVICE, OUT, IN> on ctx's DEVICE. Otherwise it falls
// through to the next statement.
#define CAST_CASE(DEVICE, IN, OUT)                                       \
  if (dst_dtype == DataTypeToEnum<OUT>::value) {                         \
    return [](OpKernelContext* ctx, const Tensor& input, Tensor* output, \
              bool truncate) {                                           \
      functor::CastFunctor<DEVICE, OUT, IN> cast_fn;                     \
      cast_fn(ctx->eigen_device<DEVICE>(), output->flat<OUT>(),          \
              input.flat<IN>(), truncate);                               \
    };                                                                   \
  }
65 
66 // The functions below are implemented in the cast_op_impl_*.cc files.
67 CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
68 
69 CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
70 
71 CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
72 
73 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
74 
75 CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
76 
77 CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
78 
79 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
80 
81 CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
82 
83 CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
84 
85 CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
86 
87 CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
88 
89 CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
90 
91 CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
92 
93 CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
94 
95 CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
96 
97 CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
98 
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// GPU counterparts of the CPU cast factories. Each takes the destination dtype
// and returns a CastFunctorType. They are declared only when building with
// CUDA or ROCm.

// Boolean and unsigned integer sources.
CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);

// Signed integer sources.
CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);

// Floating-point sources.
CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);

// Complex sources.
CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
133 
134 
135 }  // namespace tensorflow
136 
137 #endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
138