/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Simulate quantization precision loss in a float tensor by:
// 1. Quantize the tensor to fixed point numbers, which should match the target
//    quantization method when it is used in inference.
// 2. Dequantize it back to floating point numbers for the following ops, most
//    likely matmul.
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));

    string round_mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
    OP_REQUIRES(
        ctx,
        (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
        errors::InvalidArgument("Round mode string must be "
                                "'HALF_UP' or "
                                "'HALF_TO_EVEN', is '" +
                                round_mode_string + "'"));
    if (round_mode_string == "HALF_UP") {
      round_mode_ = ROUND_HALF_UP;
    } else if (round_mode_string == "HALF_TO_EVEN") {
      round_mode_ = ROUND_HALF_TO_EVEN;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
  }
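
  // A rough sketch of the fake-quantization these kernels simulate; the
  // functor implementation in quantize_and_dequantize_op.h is authoritative,
  // and the names below (min_quantized, max_quantized, scale) are purely
  // illustrative, not the header's actual variables. Assuming signed_input
  // and narrow_range == false:
  //   max_quantized = (1 << (num_bits - 1)) - 1
  //   min_quantized = -max_quantized - 1
  //   scale  = chosen so that [input_min, input_max] maps inside
  //            [min_quantized, max_quantized] while keeping 0 exactly at 0
  //   output = round(clamp(input, input_min, input_max) * scale) / scale
  // round_mode only affects ties: HALF_TO_EVEN rounds 2.5 -> 2 and 3.5 -> 4,
  // while HALF_UP breaks ties toward positive infinity (2.5 -> 3). The
  // num_bits < 62/63 bound enforced above presumably keeps the quantized
  // range representable in a 64-bit integer.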

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(
        ctx, (axis_ == -1 || axis_ < input.shape().dims()),
        errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor input_min_tensor;
    Tensor input_max_tensor;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
        range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
        narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
        round_mode_, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int num_bits_;
  int axis_;
  QuantizerRoundMode round_mode_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};
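
// A note on the per-channel path above (an illustrative reading, not part of
// the upstream comments): flat_inner_outer_dims<T, 3>(axis_ - 1) views the
// tensor as rank 3, [outer, depth, inner], where depth is the size of axis_,
// outer collapses the dimensions before it, and inner collapses the
// dimensions after it. The per-channel functor can then apply one
// (input_min, input_max) pair per slice along the middle dimension.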

// Implementation of QuantizeAndDequantizeV4GradientOp.
// When back-propagating the error through a quantized layer, the following
// paper gives evidence that clipped-ReLU is better than non-clipped:
// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
template <typename Device, typename T>
class QuantizeAndDequantizeV4GradientOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
      : OpKernel::OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& gradient = ctx->input(0);
    const Tensor& input = ctx->input(1);
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, input.shape(), &input_backprop));

    OP_REQUIRES(
        ctx, input.IsSameSize(gradient),
        errors::InvalidArgument("gradient and input must be the same size"));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    const Tensor& input_min_tensor = ctx->input(2);
    const Tensor& input_max_tensor = ctx->input(3);
    if (axis_ != -1) {
      OP_REQUIRES(
          ctx, input_min_tensor.dim_size(0) == depth,
          errors::InvalidArgument("min has incorrect size, expected ", depth,
                                  " was ", input_min_tensor.dim_size(0)));
      OP_REQUIRES(
          ctx, input_max_tensor.dim_size(0) == depth,
          errors::InvalidArgument("max has incorrect size, expected ", depth,
                                  " was ", input_max_tensor.dim_size(0)));
    }

    TensorShape min_max_shape(input_min_tensor.shape());
    Tensor* input_min_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));

    Tensor* input_max_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
        input.template flat<T>(), input_min_tensor.scalar<T>(),
        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
        input_min_backprop->template scalar<T>(),
        input_max_backprop->template scalar<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        &input_min_tensor, &input_max_tensor,
        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input_min_backprop->template flat<T>(),
        input_max_backprop->template flat<T>());
    }
  }

 private:
  int axis_;
};
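
// Roughly what the gradient functors compute (the *GradientImpl templates in
// quantize_and_dequantize_op.h are the authoritative definition): the
// incoming gradient is passed straight through wherever the corresponding
// input lies inside [input_min, input_max] and is zeroed elsewhere -- the
// clipped straight-through estimator referenced in the paper cited above.
// input_min_backprop / input_max_backprop are emitted so the op defines a
// gradient for all of its inputs; see the header for how they are filled.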

// Simulate quantization precision loss in a float tensor by:
// 1. Quantize the tensor to fixed point numbers, which should match the target
//    quantization method when it is used in inference.
// 2. Dequantize it back to floating point numbers for the following ops, most
//    likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    Tensor num_bits_tensor;
    num_bits_tensor = ctx->input(3);
    int num_bits_val = num_bits_tensor.scalar<int32>()();

    OP_REQUIRES(
        ctx, num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
        errors::InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with signed_input_ ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int axis_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};
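
// Unlike V2, QuantizeAndDequantizeV3 exposes no round_mode attribute; the
// functor calls above hard-code ROUND_HALF_TO_EVEN. num_bits is likewise
// validated on every Compute() call rather than at kernel construction,
// since it arrives as a runtime scalar tensor.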

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(
          ctx, input_min_ <= input_max_,
          errors::InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the Attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specializations for CPUDevice.

namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, QuantizerRoundMode round_mode,
                  bool narrow_range, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor out) {
    QuantizeAndDequantizePerChannelImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min_tensor,
                  typename TTypes<T>::ConstScalar input_max_tensor,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop) {
    QuantizeAndDequantizeOneScaleGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop) {
    QuantizeAndDequantizePerChannelGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice,
                                                                      float>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CPUDevice, double>;

}  // namespace functor

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<CPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
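
// The GPU registrations below pin the small auxiliary inputs (input_min,
// input_max and, for V3, num_bits) to HostMemory. A plausible reading, not
// spelled out in the original source: the kernels read these values directly
// on the host (the scalar<T>()() / scalar<int32>()() accesses above), so
// keeping them host-resident avoids a device-to-host copy before validation.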

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<GPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow