/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
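// For example (illustrative): with signed_input=true and num_bits=8, values
// are quantized to integers in [-128, 127] ([-127, 127] when
// narrow_range=true) and then rescaled to float; with signed_input=false the
// integer range is [0, 255]. The scale itself is derived from
// [input_min, input_max].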
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
    OP_REQUIRES(ctx, axis_ >= -1,
                errors::InvalidArgument("Axis must be at least -1, found ",
                                        axis_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));

    string round_mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
    OP_REQUIRES(
        ctx,
        (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
        errors::InvalidArgument("Round mode string must be "
                                "'HALF_UP' or "
                                "'HALF_TO_EVEN', is '" +
                                round_mode_string + "'"));
    if (round_mode_string == "HALF_UP") {
      round_mode_ = ROUND_HALF_UP;
    } else if (round_mode_string == "HALF_TO_EVEN") {
      round_mode_ = ROUND_HALF_TO_EVEN;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(
        ctx, (axis_ == -1 || axis_ < input.shape().dims()),
        errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
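    // depth is the number of quantization ranges: one scale for the whole
    // tensor when axis_ == -1, otherwise one scale per slice along axis_.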
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor input_min_tensor;
    Tensor input_max_tensor;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

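    // Dispatch to the per-tensor or per-channel functor. For the per-channel
    // case the input is viewed as a rank-3 tensor [outer, depth, inner] with
    // the quantization axis as the middle dimension (flat_inner_outer_dims).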
    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
        range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
        narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
        round_mode_, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int num_bits_;
  int axis_;
  QuantizerRoundMode round_mode_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// Implementation of QuantizeAndDequantizeV4GradientOp.
// When back-propagating the error through a quantized layer, the following
// paper gives evidence that clipped-ReLU is better than non-clipped:
// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
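// The gradient is effectively a clipped straight-through estimator: elements
// of `gradient` pass through where the corresponding input lies inside
// [input_min, input_max], while out-of-range elements contribute to the
// backprop of the range bounds instead (see the gradient functors in
// quantize_and_dequantize_op.h).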
template <typename Device, typename T>
class QuantizeAndDequantizeV4GradientOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& gradient = ctx->input(0);
    const Tensor& input = ctx->input(1);
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, input.shape(), &input_backprop));

    OP_REQUIRES(
        ctx, input.IsSameSize(gradient),
        errors::InvalidArgument("gradient and input must be the same size"));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    const Tensor& input_min_tensor = ctx->input(2);
    const Tensor& input_max_tensor = ctx->input(3);
    if (axis_ != -1) {
      OP_REQUIRES(
          ctx, input_min_tensor.dim_size(0) == depth,
          errors::InvalidArgument("min has incorrect size, expected ", depth,
                                  " was ", input_min_tensor.dim_size(0)));
      OP_REQUIRES(
          ctx, input_max_tensor.dim_size(0) == depth,
          errors::InvalidArgument("max has incorrect size, expected ", depth,
                                  " was ", input_max_tensor.dim_size(0)));
    }

    TensorShape min_max_shape(input_min_tensor.shape());
    Tensor* input_min_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));

    Tensor* input_max_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
        input.template flat<T>(), input_min_tensor.scalar<T>(),
        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
        input_min_backprop->template scalar<T>(),
        input_max_backprop->template scalar<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        &input_min_tensor, &input_max_tensor,
        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input_min_backprop->template flat<T>(),
        input_max_backprop->template flat<T>());
    }
  }

 private:
  int axis_;
};

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

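    // Unlike V2, num_bits arrives as a tensor input rather than an attr, so it
    // is read and validated at execution time.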
    Tensor num_bits_tensor;
    num_bits_tensor = ctx->input(3);
    int num_bits_val = num_bits_tensor.scalar<int32>()();

    OP_REQUIRES(
        ctx, num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
        errors::InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with signed_input_ ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int axis_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(
          ctx, input_min_ <= input_max_,
          errors::InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the Attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specializations for CPUDevice.

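// These functors simply forward to the shared *Impl templates declared in
// quantize_and_dequantize_op.h, which contain the actual element-wise logic.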
namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, QuantizerRoundMode round_mode,
                  bool narrow_range, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor out) {
    QuantizeAndDequantizePerChannelImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min_tensor,
                  typename TTypes<T>::ConstScalar input_max_tensor,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop) {
    QuantizeAndDequantizeOneScaleGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop) {
    QuantizeAndDequantizePerChannelGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice,
                                                                      float>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CPUDevice, double>;

}  // namespace functor

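// Note that QuantizeAndDequantizeV4 reuses the V2 kernel: the two ops share
// forward semantics, and V4 exists so that a gradient
// (QuantizeAndDequantizeV4Grad) can be registered for it.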
#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<CPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<GPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow