1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 18 19 #define EIGEN_USE_THREADS 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/kernels/cast_op.h" 23 24 namespace tensorflow { 25 26 namespace functor { 27 28 CAST_FUNCTORS(Eigen::ThreadPoolDevice); 29 30 31 } // namespace functor 32 33 #define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 34 FN(arg0, arg1, bool); \ 35 FN(arg0, arg1, uint8); \ 36 FN(arg0, arg1, uint16); \ 37 FN(arg0, arg1, uint32); \ 38 FN(arg0, arg1, uint64); \ 39 FN(arg0, arg1, int8); \ 40 FN(arg0, arg1, int16); \ 41 FN(arg0, arg1, int32); \ 42 FN(arg0, arg1, int64); \ 43 FN(arg0, arg1, float); \ 44 FN(arg0, arg1, double); \ 45 FN(arg0, arg1, std::complex<float>); \ 46 FN(arg0, arg1, std::complex<double>) 47 48 #define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 49 CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 50 FN(arg0, arg1, Eigen::half); 51 52 #define CURRY_TYPES3(FN, arg0, arg1) \ 53 CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 54 FN(arg0, arg1, bfloat16); 55 56 #define CAST_CASE(DEVICE, IN, OUT) \ 57 if (DataTypeToEnum<OUT>::value == dst_dtype) { \ 58 return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ 59 bool truncate) { \ 60 functor::CastFunctor<DEVICE, OUT, IN> func; \ 61 func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \ 62 truncate); \ 63 }; \ 64 } 65 66 // The functions below are implemented in the cast_op_impl_*.cc files. 67 CastFunctorType GetCpuCastFromBool(DataType dst_dtype); 68 69 CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); 70 71 CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); 72 73 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 74 75 CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); 76 77 CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); 78 79 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 80 81 CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); 82 83 CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); 84 85 CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); 86 87 CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); 88 89 CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); 90 91 CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); 92 93 CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); 94 95 CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); 96 97 CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); 98 99 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ 100 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) 101 // Same, for GPU. 102 CastFunctorType GetGpuCastFromBool(DataType dst_dtype); 103 104 CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); 105 106 CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); 107 108 CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); 109 110 CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); 111 112 CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); 113 114 CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); 115 116 CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); 117 118 CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); 119 120 CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); 121 122 CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); 123 124 CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); 125 126 CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); 127 128 CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); 129 130 CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); 131 132 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 133 134 135 } // namespace tensorflow 136 137 #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 138