/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)
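
// As a rough illustration (a sketch of the preprocessor expansion, not code
// that appears here verbatim), using this macro with REGISTER_CAST_GPU,
// defined further below,
//   CURRY_TYPES2(REGISTER_CAST_GPU, float);
// expands to one invocation per destination type:
//   REGISTER_CAST_GPU(float, bool);
//   REGISTER_CAST_GPU(float, uint8);
//   ...
//   REGISTER_CAST_GPU(float, std::complex<double>);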

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their non-quantized
  // versions, so we use the non-quantized implementation for casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

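// Descriptive summary of Compute(): when work_ is null the cast is an
// identity and the input tensor is forwarded unchanged; when the external
// dtype is quantized, the input is first bitcast to the plain integer type
// chosen in the constructor (e.g. DT_QUINT8 -> DT_UINT8), the plain cast
// kernel runs, and the external dtype is restored on the output; otherwise
// the cast kernel runs directly on the input.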
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else if (external_src_dtype_ != src_dtype_ ||
             external_dst_dtype_ != dst_dtype_) {
    Tensor in;
    // If the type is a quantized type, we need to do a bitcast since
    // src_dtype_ is different from external_src_dtype_.
    OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out, use_truncation_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)
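
// For illustration (a sketch of the preprocessor expansion, not code that
// appears here verbatim), REGISTER_CAST_GPU(float, Eigen::half) expands to
//   REGISTER_KERNEL_BUILDER(Name("Cast")
//                               .TypeConstraint<float>("SrcT")
//                               .TypeConstraint<Eigen::half>("DstT")
//                               .Device(DEVICE_GPU),
//                           GpuCastOp)
// i.e. GpuCastOp handles "Cast" on DEVICE_GPU for that SrcT/DstT pair. The
// CURRY_TYPES2 invocations below stamp out one such registration per
// destination type for each source type.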

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
    !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
#else

#define CURRY_SUBSET_OF_TYPES(FN, arg0) \
  FN(arg0, uint8);                      \
  FN(arg0, uint16);                     \
  FN(arg0, uint32);                     \
  FN(arg0, uint64);                     \
  FN(arg0, std::complex<float>);        \
  FN(arg0, std::complex<double>)

CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, bool);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int8);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int16);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int32);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int64);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, Eigen::half);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, float);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, double);

#undef CURRY_SUBSET_OF_TYPES

#endif

CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_DEFAULT).HostMemory("x").HostMemory("y"),
    CpuCastOp);
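// Note: the DEVICE_DEFAULT registration above also uses CpuCastOp because
// both its input ("x") and output ("y") are constrained to host memory, so
// the conversion itself always runs on the CPU regardless of which device
// the op is placed on.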
}  // end namespace tensorflow