/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/fake_quant_ops_functor.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

using tensorflow::BinaryElementWiseOp;
using tensorflow::DEVICE_CPU;
#if GOOGLE_CUDA
using tensorflow::DEVICE_GPU;
#endif
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TTypes;  // NOLINT This is needed in CUDA mode, do not remove.
using tensorflow::UnaryElementWiseOp;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace {
bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
}  // namespace
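
// All fake-quant kernels below map `num_bits` and `narrow_range` to an
// integer quantization range [quant_min, quant_max]:
//   quant_min = narrow_range ? 1 : 0
//   quant_max = 2^num_bits - 1
// For example, num_bits = 8 yields [0, 255], or [1, 255] with narrow_range,
// which is commonly used so the corresponding signed range is symmetric
// (e.g. [-127, 127]).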

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
// core/ops/array_ops.cc.
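//
// A sketch of the fake-quantization math the functors apply (see
// fake_quant_ops_functor.h for the authoritative version): [min, max] is
// first nudged so that zero is exactly representable, then each input is
// clamped to the nudged interval, scaled to [quant_min, quant_max], rounded
// to the nearest integer, and mapped back to float:
//   scale  = (max - min) / (quant_max - quant_min)
//   output = round((clamp(x, min, max) - min) / scale) * scale + min
// (with min and max replaced by their nudged values).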
template <typename Device>
class FakeQuantWithMinMaxArgsOp
    : public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
 public:
  typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
      : Base::UnaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    FakeQuantWithMinMaxArgsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
            quant_min_, quant_max_, output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxArgsGradientOp
    : public BinaryElementWiseOp<float,
                                 FakeQuantWithMinMaxArgsGradientOp<Device>> {
 public:
  typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
      Base;
  explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
      : Base::BinaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

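  // BinaryElementWiseOp dispatches to Operate<NDIMS> for each supported input
  // rank; delegating to a non-template body keeps a single copy of the code
  // instead of one instantiation per NDIMS.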
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& gradient,
               const Tensor& input, Tensor* output) {
    OperateNoTemplate(context, gradient, input, output);
  }

  void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
                         const Tensor& input, Tensor* output) {
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min_, max_, quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

// Forward declarations for functor specializations for GPU.
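// The GPU specializations are defined and explicitly instantiated in the
// companion .cu.cc file; the declarations and `extern template` statements
// below stop this translation unit from instantiating them itself.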
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    const float min, const float max, const int quant_min, const int quant_max,
    typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
                        FakeQuantWithMinMaxArgsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs, const float min, const float max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const Tensor& min = context->input(1);
    const Tensor& max = context->input(2);

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    FakeQuantWithMinMaxVarsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(),
            min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation
// in core/ops/array_ops.cc.
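//
// The gradient functor implements a straight-through estimator: gradients
// pass through unchanged where the input lies inside the (nudged) [min, max]
// interval, while gradients for inputs clamped below min or above max are
// accumulated into the backprops for min and max respectively (see
// fake_quant_ops_functor.h).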
template <typename Device>
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const Tensor& min = context->input(2);
    const Tensor& max = context->input(3);

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape scalar_shape;
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scalar_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scalar_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
            quant_min_, quant_max_, grad_wrt_input->flat<float>(),
            grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
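// `min` and `max` are registered with HostMemory below: the functor reads
// their scalar values on the host (to compute the nudged quantization range)
// before launching the device computation, so these inputs are kept in host
// memory rather than on the GPU.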
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
    typename TTypes<float>::Scalar backprop_wrt_min,
    typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
// in core/ops/array_ops.cc.
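//
// Per-channel variant: the last dimension of the input is treated as the
// channel axis, with one (min, max) pair per channel. The kernel views the
// input as a 2-D [outer, depth] matrix via flat_inner_dims and requires min
// and max to be vectors of length depth.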
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(1);
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(2);
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    FakeQuantWithMinMaxVarsPerChannelFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat_inner_dims<float, 2>(),
            min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
            output->flat_inner_dims<float, 2>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
// documentation in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
      OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(2);
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(3);
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape min_max_shape({depth});
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, min_max_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, min_max_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsPerChannelGradientFunctor<Device> functor;
    functor(
        context->eigen_device<Device>(), gradient.flat_inner_dims<float, 2>(),
        input.flat_inner_dims<float, 2>(), min.vec<float>(), max.vec<float>(),
        quant_min_, quant_max_, grad_wrt_input->flat_inner_dims<float, 2>(),
        grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstFlat min,
    typename TTypes<float>::ConstFlat max, const int quant_min,
    const int quant_max, typename TTypes<float>::Matrix outputs);
extern template struct FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix gradients,
    typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Matrix backprops_wrt_input,
    typename TTypes<float>::Vec backprop_wrt_min,
    typename TTypes<float>::Vec backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<
    GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow
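
// These kernels back the Python-level APIs under tf.quantization
// (fake_quant_with_min_max_args / _vars / _vars_per_channel and their
// gradients). A minimal usage sketch, with the documented Python defaults:
//   y = tf.quantization.fake_quant_with_min_max_args(
//       x, min=-6.0, max=6.0, num_bits=8, narrow_range=False)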