1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/kernels/cwise_ops.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
// Rounding rule applied to the scaled value when simulating integer
// quantization. The ClampScaleAndRound / ScaleAndRound overloads in this file
// that take a QuantizerRoundMode map each enumerator to an Eigen rounding
// functor.
enum QuantizerRoundMode {
  // Round half up: if the fraction of y is exactly 0.5, then
  // round(y) = y + 0.5
  // E.g., -5.5 gets rounded to -5, -5.4 goes to -5,
  // 5.4 goes to 5, and 5.5 goes to 6.
  // Dispatched to Eigen::internal::scalar_round_up_op<T>.
  ROUND_HALF_UP,
  // Round half to even: if the fraction of y is exactly 0.5, then round(y) is
  // the nearest even integer to y.
  // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes
  // -24, and -24.5 gets rounded to -24.
  // Dispatched to Eigen::internal::scalar_round_half_to_even_op<T>.
  ROUND_HALF_TO_EVEN,
};
40 
41 namespace functor {
42 
43 // TODO(pauldonnelly): 'signed_input' should really be called 'signed_output'.
44 
// Functor interface for quantize-and-dequantize with a single scale shared by
// the whole input. Only the call operator is declared here; its definition is
// not part of this header.
// NOTE(review): the parameter list matches
// QuantizeAndDequantizeOneScaleImpl::Compute below, so the per-device
// definitions presumably forward to it -- confirm in the kernel sources.
//
// Parameters (meanings as used by QuantizeAndDequantizeOneScaleImpl::Compute):
//   d                - device the Eigen expressions are evaluated on.
//   input            - values to quantize, viewed as a vector.
//   signed_input     - selects a signed rather than unsigned integer range.
//   num_bits         - bit width of the simulated integer type.
//   range_given      - true if the min/max tensors already hold the range.
//   input_min_tensor - range minimum (read if range_given, written otherwise).
//   input_max_tensor - range maximum (read if range_given, written otherwise).
//   round_mode       - rounding rule, see QuantizerRoundMode.
//   narrow_range     - for signed ranges, drops the most negative value.
//   output           - receives the quantized-then-dequantized values.
template <typename Device, typename T>
struct QuantizeAndDequantizeOneScaleFunctor {
  void operator()(const Device& d, typename TTypes<T>::ConstVec input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T>::Vec output);
};
53 
// Functor interface for quantize-and-dequantize with one scale per channel.
// Only the call operator is declared here; its definition is not part of this
// header.
// NOTE(review): the parameter list matches
// QuantizeAndDequantizePerChannelImpl::Compute below, so the per-device
// definitions presumably forward to it -- confirm in the kernel sources.
//
// Parameters (meanings as used by QuantizeAndDequantizePerChannelImpl::Compute):
//   d                - device the Eigen expressions are evaluated on.
//   input            - rank-3 view of the values; dimension 1 is the channel.
//   signed_input     - selects a signed rather than unsigned integer range.
//   num_bits         - bit width of the simulated integer type.
//   range_given      - true if the min/max tensors already hold the ranges.
//   input_min_tensor - per-channel minimum (read if range_given, written
//                      otherwise).
//   input_max_tensor - per-channel maximum (read if range_given, written
//                      otherwise).
//   round_mode       - rounding rule, see QuantizerRoundMode.
//   narrow_range     - for signed ranges, drops the most negative value.
//   output           - rank-3 view receiving the result.
template <typename Device, typename T>
struct QuantizeAndDequantizePerChannelFunctor {
  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor output);
};
62 
// Functor interface for the gradient of single-scale quantize-and-dequantize.
// Only the call operator is declared here; its definition is not part of this
// header.
// NOTE(review): the parameter list matches
// QuantizeAndDequantizeOneScaleGradientImpl::Compute below, so the per-device
// definitions presumably forward to it -- confirm in the kernel sources.
//
// Parameters (meanings as used by
// QuantizeAndDequantizeOneScaleGradientImpl::Compute):
//   d                  - device the Eigen expressions are evaluated on.
//   gradient           - incoming gradient, flattened.
//   input              - forward-pass input, flattened.
//   input_min          - scalar lower bound of the pass-through interval.
//   input_max          - scalar upper bound of the pass-through interval.
//   input_backprop     - receives the gradient with respect to `input`.
//   input_min_backprop - receives the gradient with respect to `input_min`.
//   input_max_backprop - receives the gradient with respect to `input_max`.
template <typename Device, typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min,
                  typename TTypes<T>::ConstScalar input_max,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop);
};
73 
// Functor interface for the gradient of per-channel quantize-and-dequantize.
// Only the call operator is declared here; its definition is not part of this
// header.
// NOTE(review): the parameter list matches
// QuantizeAndDequantizePerChannelGradientImpl::Compute below, so the
// per-device definitions presumably forward to it -- confirm in the kernel
// sources.
//
// Parameters (meanings as used by
// QuantizeAndDequantizePerChannelGradientImpl::Compute):
//   d                  - device the Eigen expressions are evaluated on.
//   gradient           - incoming gradient, rank-3 view; dimension 1 is the
//                        channel.
//   input              - forward-pass input, same rank-3 layout.
//   input_min_tensor   - per-channel lower bounds of the pass-through interval.
//   input_max_tensor   - per-channel upper bounds of the pass-through interval.
//   input_backprop     - receives the gradient with respect to `input`.
//   input_min_backprop - receives the gradient with respect to the minimums.
//   input_max_backprop - receives the gradient with respect to the maximums.
template <typename Device, typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor {
  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop);
};
84 
85 // The implementation below runs on both CPU and GPU.
86 template <typename Device, typename T, typename Func,
87           typename Vec = typename TTypes<T>::Vec,
88           typename ConstVec = typename TTypes<T>::ConstVec>
ClampScaleAndRound(const Device & d,ConstVec input,T min_range,T max_range,T scale,T inverse_scale,Func round_func,Vec output)89 void ClampScaleAndRound(const Device& d, ConstVec input, T min_range,
90                         T max_range, T scale, T inverse_scale, Func round_func,
91                         Vec output) {
92   output.device(d) = (input.cwiseMin(max_range).cwiseMax(min_range) * scale)
93                          .unaryExpr(round_func) *
94                      inverse_scale;
95 }
96 
97 // The implementation below runs on both CPU and GPU.
98 template <typename Device, typename T, typename Vec = typename TTypes<T>::Vec,
99           typename ConstVec = typename TTypes<T>::ConstVec>
ClampScaleAndRound(const Device & d,ConstVec input,T min_range,T max_range,T scale,T inverse_scale,QuantizerRoundMode round_mode,Vec output)100 void ClampScaleAndRound(const Device& d, ConstVec input, T min_range,
101                         T max_range, T scale, T inverse_scale,
102                         QuantizerRoundMode round_mode, Vec output) {
103   switch (round_mode) {
104     case ROUND_HALF_TO_EVEN:
105       ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
106                          Eigen::internal::scalar_round_half_to_even_op<T>(),
107                          output);
108       break;
109     case ROUND_HALF_UP:
110       ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
111                          Eigen::internal::scalar_round_up_op<T>(), output);
112       break;
113   }
114 }
115 
116 // The implementation below runs on both CPU and GPU.
117 template <typename Device, typename T, typename Func,
118           typename Vec = typename TTypes<T>::Vec,
119           typename ConstVec = typename TTypes<T>::ConstVec>
ScaleAndRound(const Device & d,ConstVec input,T scale,T inverse_scale,Func round_func,Vec output)120 void ScaleAndRound(const Device& d, ConstVec input, T scale, T inverse_scale,
121                    Func round_func, Vec output) {
122   output.device(d) = (input * scale).unaryExpr(round_func) * inverse_scale;
123 }
124 
125 // The implementation below runs on both CPU and GPU.
126 template <typename Device, typename T, typename Vec = typename TTypes<T>::Vec,
127           typename ConstVec = typename TTypes<T>::ConstVec>
ScaleAndRound(const Device & d,ConstVec input,T scale,T inverse_scale,QuantizerRoundMode round_mode,Vec output)128 void ScaleAndRound(const Device& d, ConstVec input, T scale, T inverse_scale,
129                    QuantizerRoundMode round_mode, Vec output) {
130   switch (round_mode) {
131     case ROUND_HALF_TO_EVEN:
132       ScaleAndRound(d, input, scale, inverse_scale,
133                     Eigen::internal::scalar_round_half_to_even_op<T>(), output);
134       break;
135     case ROUND_HALF_UP:
136       ScaleAndRound(d, input, scale, inverse_scale,
137                     Eigen::internal::scalar_round_up_op<T>(), output);
138       break;
139   }
140 }
141 
142 template <typename T>
ComputeQuantizationRange(bool signed_input,int num_bits,QuantizerRoundMode round_mode,bool narrow_range,T * min_range,T * max_range,T * scale,T * inverse_scale)143 void ComputeQuantizationRange(bool signed_input, int num_bits,
144                               QuantizerRoundMode round_mode, bool narrow_range,
145                               T* min_range, T* max_range, T* scale,
146                               T* inverse_scale) {
147   // Calculate the range for the simulated integer quantization:
148   // e.g. [-127,127] for signed = true, narrow_range = true, num_bits = 8,
149   // or [-128,127] for signed = true, narrow_range = false, num_bits = 8,
150   // or [0, 255] for signed = false, num_bits = 8.
151   const int64 min_quantized = signed_input ? narrow_range
152                                                  ? -(1ULL << (num_bits - 1)) + 1
153                                                  : -(1ULL << (num_bits - 1))
154                                            : 0;
155   const int64 max_quantized =
156       signed_input ? (1ULL << (num_bits - 1)) - 1 : (1ULL << num_bits) - 1;
157   // Determine the maximum scaling factor that would scale
158   // [min_range, max_range] to not exceed [min_quantized, max_quantized],
159   // while keeping 0 unchanged.
160   const T scale_from_min_side = (min_quantized * *min_range > 0)
161                                     ? min_quantized / *min_range
162                                     : std::numeric_limits<T>::max();
163   const T scale_from_max_side = (max_quantized * *max_range > 0)
164                                     ? max_quantized / *max_range
165                                     : std::numeric_limits<T>::max();
166 
167   // Note: Avoids changing the side of the range that determines scale.
168   if (scale_from_min_side < scale_from_max_side) {
169     *scale = scale_from_min_side;
170     *inverse_scale = *min_range / min_quantized;
171     *max_range = max_quantized * *inverse_scale;
172   } else {
173     *scale = scale_from_max_side;
174     *inverse_scale = *max_range / max_quantized;
175     *min_range = min_quantized * *inverse_scale;
176   }
177 }
178 
179 // The implementation below runs on both CPU and GPU.
180 template <typename Device, typename T>
181 struct QuantizeAndDequantizeOneScaleImpl {
ComputeQuantizeAndDequantizeOneScaleImpl182   static void Compute(const Device& d, typename TTypes<T>::ConstVec input,
183                       bool signed_input, int num_bits, bool range_given,
184                       Tensor* input_min_tensor, Tensor* input_max_tensor,
185                       QuantizerRoundMode round_mode, bool narrow_range,
186                       typename TTypes<T>::Vec output) {
187     T min_range;
188     T max_range;
189     auto input_min = input_min_tensor->scalar<T>();
190     auto input_max = input_max_tensor->scalar<T>();
191     if (!range_given) {
192       input_min.device(d) = input.minimum();
193       input_max.device(d) = input.maximum();
194       d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
195       d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
196     } else {
197       // Copy the range values from their respective tensors on the host.
198       min_range = input_min_tensor->scalar<T>()();
199       max_range = input_max_tensor->scalar<T>()();
200     }
201 
202     T scale, inverse_scale;
203     ComputeQuantizationRange(signed_input, num_bits, round_mode, narrow_range,
204                              &min_range, &max_range, &scale, &inverse_scale);
205 
206     if (range_given) {
207       // Note: The clamping here is to avoid overflow in the quantized type.
208       // The semantics of the op does not guarantee to clamp to the specified
209       // min_range and max_range - because we may have changed either min_range
210       // or max_range.
211       ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
212                          round_mode, output);
213     } else {
214       ScaleAndRound(d, input, scale, inverse_scale, round_mode, output);
215     }
216   }
217 };
218 
219 // The implementation below runs on both CPU and GPU.
220 
221 template <typename Device, typename T>
222 struct QuantizeAndDequantizePerChannelImpl {
ComputeQuantizeAndDequantizePerChannelImpl223   static void Compute(const Device& d, typename TTypes<T, 3>::ConstTensor input,
224                       bool signed_input, int num_bits, bool range_given,
225                       Tensor* input_min_tensor, Tensor* input_max_tensor,
226                       QuantizerRoundMode round_mode, bool narrow_range,
227                       typename TTypes<T, 3>::Tensor output) {
228     using Index = typename tensorflow::TTypes<T>::ConstTensor::Index;
229     int num_channels = input.dimension(1);
230     auto input_min = input_min_tensor->vec<T>();
231     auto input_max = input_max_tensor->vec<T>();
232     std::vector<T> min_range(num_channels);
233     std::vector<T> max_range(num_channels);
234 
235     if (!range_given) {
236 #if !defined(EIGEN_HAS_INDEX_LIST)
237       Eigen::array<int, 2> reduce_dims{{0, 2}};
238 #else
239       Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2> > reduce_dims;
240 #endif
241       input_min.device(d) = input.minimum(reduce_dims);
242       input_max.device(d) = input.maximum(reduce_dims);
243       d.memcpyDeviceToHost(min_range.data(), input_min.data(),
244                            num_channels * sizeof(T));
245       d.memcpyDeviceToHost(max_range.data(), input_max.data(),
246                            num_channels * sizeof(T));
247     } else {
248       // Copy the range values from their respective tensors on the host.
249       std::memcpy(min_range.data(), input_min_tensor->vec<T>().data(),
250                   num_channels * sizeof(T));
251       std::memcpy(max_range.data(), input_max_tensor->vec<T>().data(),
252                   num_channels * sizeof(T));
253     }
254 
255     for (Index i = 0; i < num_channels; ++i) {
256       const auto input_chip = input.template chip<1>(i);
257       auto output_chip = output.template chip<1>(i);
258 
259       T scale, inverse_scale;
260       ComputeQuantizationRange(signed_input, num_bits, round_mode, narrow_range,
261                                &min_range[i], &max_range[i], &scale,
262                                &inverse_scale);
263       if (range_given) {
264         ClampScaleAndRound(d, input_chip, min_range[i], max_range[i], scale,
265                            inverse_scale, round_mode, output_chip);
266       } else {
267         ScaleAndRound(d, input_chip, scale, inverse_scale, round_mode,
268                       output_chip);
269       }
270     }
271   }
272 };
273 
274 template <typename Device, typename T>
275 struct QuantizeAndDequantizeOneScaleGradientImpl {
ComputeQuantizeAndDequantizeOneScaleGradientImpl276   static void Compute(const Device& d, typename TTypes<T>::ConstFlat gradient,
277                       typename TTypes<T>::ConstFlat input,
278                       typename TTypes<T>::ConstScalar input_min,
279                       typename TTypes<T>::ConstScalar input_max,
280                       typename TTypes<T>::Flat input_backprop,
281                       typename TTypes<T>::Scalar input_min_backprop,
282                       typename TTypes<T>::Scalar input_max_backprop) {
283     const T min_val = input_min();
284     const T max_val = input_max();
285     const auto in_range =
286         (input >= min_val && input <= max_val)
287             .select(input.constant(1.0f), input.constant(0.0f));
288     input_backprop.device(d) = gradient * in_range;
289     input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
290     input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
291   }
292 };
293 
294 template <typename Device, typename T>
295 struct QuantizeAndDequantizePerChannelGradientImpl {
ComputeQuantizeAndDequantizePerChannelGradientImpl296   static void Compute(const Device& d,
297                       typename TTypes<T, 3>::ConstTensor gradient,
298                       typename TTypes<T, 3>::ConstTensor input,
299                       const Tensor* input_min_tensor,
300                       const Tensor* input_max_tensor,
301                       typename TTypes<T, 3>::Tensor input_backprop,
302                       typename TTypes<T>::Flat input_min_backprop,
303                       typename TTypes<T>::Flat input_max_backprop) {
304     using Index = typename tensorflow::TTypes<T>::ConstTensor::Index;
305     auto input_min = input_min_tensor->vec<T>();
306     auto input_max = input_max_tensor->vec<T>();
307     int num_channels = input.dimension(1);
308     for (Index i = 0; i < num_channels; ++i) {
309       const auto gradient_chip = gradient.template chip<1>(i);
310       const auto input_chip = input.template chip<1>(i);
311       const T min_val = input_min(i);
312       const T max_val = input_max(i);
313       const auto in_range =
314           (input_chip >= min_val && input_chip <= max_val)
315               .select(input_chip.constant(1.0f), input_chip.constant(0.0f));
316       input_backprop.template chip<1>(i).device(d) = gradient_chip * in_range;
317     }
318     input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
319     input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
320   }
321 };
322 
323 }  // end of namespace functor
324 }  // end of namespace tensorflow
325 
326 #endif  // TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
327