/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

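// For reference, CURRY_TYPES2(FN, arg0) below expands to one FN invocation per
// destination type, e.g. CURRY_TYPES2(REGISTER_CAST_GPU, float) becomes
//   REGISTER_CAST_GPU(float, bool);
//   REGISTER_CAST_GPU(float, uint8);
//   ...
//   REGISTER_CAST_GPU(float, std::complex<double>)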
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their
  // non-quantized counterparts, so we use the non-quantized implementation
  // for casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

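// Compute() has three paths: when work_ is null (SrcT == DstT) the input is
// forwarded to the output unchanged; when either endpoint is a quantized type,
// the input is bitcast to its plain integer counterpart, the cast functor is
// run, and the output's dtype is then set back to the external (quantized)
// type; otherwise the functor casts the input directly into the allocated
// output.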
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else if (external_src_dtype_ != src_dtype_ ||
             external_dst_dtype_ != dst_dtype_) {
    Tensor in;
    // If the type is a quantized type we need to do a bitcast since the
    // src_dtype_ is different from external_src_dtype_.
    OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out, use_truncation_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not least based on F16C for Haswell or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
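// Usage note: in the Python API, a call such as tf.cast(x, tf.int32) creates a
// "Cast" node, which the runtime resolves to one of the kernel registrations
// in this file based on the device and the SrcT/DstT type constraints.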

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

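// Each CURRY_TYPES2(REGISTER_CAST_GPU, SrcT) line below registers one GPU Cast
// kernel for SrcT paired with every destination type listed in CURRY_TYPES2.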
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
    !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
#else

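// Note: when the MLIR-generated GPU kernels are enabled, only a subset of the
// registrations above is kept for these source types, presumably the
// destination types (unsigned integers and complex) that the MLIR-generated
// kernels do not cover.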
#define CURRY_SUBSET_OF_TYPES(FN, arg0) \
  FN(arg0, uint8);                      \
  FN(arg0, uint16);                     \
  FN(arg0, uint32);                     \
  FN(arg0, uint64);                     \
  FN(arg0, std::complex<float>);        \
  FN(arg0, std::complex<double>)

CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, bool);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int8);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int16);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int32);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int64);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, Eigen::half);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, float);
CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, double);

#undef CURRY_SUBSET_OF_TYPES

#endif

CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_DEFAULT).HostMemory("x").HostMemory("y"),
    CpuCastOp);
}  // end namespace tensorflow