1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
17 #define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
18 
19 #include <cmath>
20 #define EIGEN_USE_THREADS
21 
22 // This is a set of functions that standardizes how quantized values are
23 // interpreted as float numbers.
24 // All of the current implementations are for reference and have not been
25 // optimized. They should be implementable using fixed point representations
26 // to avoid a dependency on floating-point hardware.
27 
28 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
29 #define QUANTIZATION_UTILS_USE_NEON
30 #include <arm_neon.h>
31 #endif
32 
33 #include <array>
34 
35 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
36 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
37 #include "public/gemmlowp.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/lib/core/threadpool.h"
40 
41 namespace tensorflow {
42 
43 // We have to be able to detect and handle overflows in int32, so this function
44 // uses doubles and int64's to make sure we have enough room.
45 template <class T>
FloatToQuantizedUnclamped(float input,float range_min,float range_max)46 inline int64 FloatToQuantizedUnclamped(float input, float range_min,
47                                        float range_max) {
48   const int64 lowest_quantized =
49       static_cast<double>(Eigen::NumTraits<T>::lowest());
50   if (range_min == range_max) {
51     return lowest_quantized;
52   }
53   const int number_of_bits = sizeof(T) * 8;
54   const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
55   const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
56   const double range = ((range_max - range_min) * range_adjust);
57   const double range_scale = (number_of_steps / range);
58   int64 quantized =
59       (round(input * range_scale) - round(range_min * range_scale));
60   quantized += lowest_quantized;
61   return quantized;
62 }
63 
64 template <>
65 inline int64 FloatToQuantizedUnclamped<float>(float input, float range_min,
66                                               float range_max) {
67   return -1;
68 }
69 
70 // This converts the float into the final quantized type, clamping/saturating
71 // any over or underflows.
72 template <class T>
FloatToQuantized(float input,float range_min,float range_max)73 T FloatToQuantized(float input, float range_min, float range_max) {
74   if (std::is_same<T, float>::value) {
75     // Specialization for float. This is used in reference implementation
76     // for float which is useful to compare performance between float
77     // and quantized type.
78     return input;
79   }
80   int64 quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max);
81   const int64 lowest_quantized =
82       static_cast<int64>(Eigen::NumTraits<T>::lowest());
83   const int64 highest_quantized =
84       static_cast<int64>(Eigen::NumTraits<T>::highest());
85   quantized = std::max(quantized, lowest_quantized);
86   quantized = std::min(quantized, highest_quantized);
87   return static_cast<T>(static_cast<int32>(quantized));
88 }
89 
90 template <class T>
QuantizedToFloat(T input,float range_min,float range_max)91 float QuantizedToFloat(T input, float range_min, float range_max) {
92   if (std::is_same<T, float>::value) {
93     // Specialization for float. This is used in reference implementation
94     // for float which is useful to compare performance between float
95     // and quantized type.
96     return input;
97   }
98   if (range_min == range_max) {
99     return range_min;
100   }
101   const int number_of_bits = sizeof(T) * 8;
102   const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
103   const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
104   const double range = ((range_max - range_min) * range_adjust);
105   const double range_scale = (range / number_of_steps);
106   const int64 lowest_quantized =
107       static_cast<int64>(Eigen::NumTraits<T>::lowest());
108   const double offset_input = static_cast<double>(input) - lowest_quantized;
109   // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert
110   // range_scale to a float, otherwise range_min_rounded might be slightly
111   // different.
112   const double range_min_rounded =
113       std::round(range_min / static_cast<float>(range_scale)) *
114       static_cast<float>(range_scale);
115   const double result = range_min_rounded + (offset_input * range_scale);
116   return static_cast<float>(result);
117 }
118 
119 template <class T>
FloatForOneQuantizedLevel(float range_min,float range_max)120 float FloatForOneQuantizedLevel(float range_min, float range_max) {
121   const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
122   const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
123   const float float_for_one_quantized_level =
124       (range_max - range_min) / (highest - lowest);
125   return float_for_one_quantized_level;
126 }
127 
128 template <class T1, class T2, class T3>
QuantizationRangeForMultiplication(float min_a,float max_a,float min_b,float max_b,float * min_c,float * max_c)129 void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
130                                         float max_b, float* min_c,
131                                         float* max_c) {
132   const float a_float_for_one_quant_level =
133       FloatForOneQuantizedLevel<T1>(min_a, max_a);
134   const float b_float_for_one_quant_level =
135       FloatForOneQuantizedLevel<T2>(min_b, max_b);
136 
137   const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
138   const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
139   const float c_float_for_one_quant_level =
140       a_float_for_one_quant_level * b_float_for_one_quant_level;
141 
142   *min_c = c_float_for_one_quant_level * c_lowest;
143   *max_c = c_float_for_one_quant_level * c_highest;
144 }
145 
// input_array is an eigen Tensor.  q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
//
// Both arguments are parenthesized wherever they are expanded, so passing an
// arbitrary expression (rather than a plain identifier) cannot change how the
// surrounding operators bind. Note that q2f is expanded several times.
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f)                      \
  (((q2f).range_min_rounded -                                        \
    (q2f).lowest_quantized() * (q2f).range_scale) +                  \
   (input_array).template cast<float>() * (q2f).range_scale)
152 
// input_array is an eigen Tensor.  f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
//
// The value is scaled and rounded, shifted so that the bottom of the range
// lands on the lowest code, clamped to f2q's float bounds, and then cast via
// int32 to OutputType.
//
// input_array and f2q are parenthesized wherever they are expanded, so an
// expression argument such as `a + b` is scaled as a whole. Note that f2q is
// expanded several times.
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType)       \
  ((((input_array) * (f2q).range_scale).round() -               \
    ((f2q).range_min_scaled - (f2q).lowest_quantized()))        \
       .cwiseMax((f2q).lower_bound_float())                     \
       .cwiseMin((f2q).upper_bound_float())                     \
       .template cast<int32>()                                  \
       .template cast<OutputType>())
164 
165 // For use with DEQUANTIZE_WITH_EIGEN.
166 template <typename T>
167 struct QuantizedToFloatStruct {
168   static constexpr int number_of_bits = sizeof(T) * 8;
169   static constexpr int64 number_of_steps = static_cast<int64>(1)
170                                            << number_of_bits;
171 
lowest_quantizedQuantizedToFloatStruct172   static float lowest_quantized() {
173     return static_cast<float>(Eigen::NumTraits<T>::lowest());
174   }
175 
QuantizedToFloatStructQuantizedToFloatStruct176   QuantizedToFloatStruct(float range_min, float range_max)
177       : range_min(range_min),
178         range_scale((range_max - range_min) / (number_of_steps - 1.0)),
179         range_min_rounded(range_max == range_min
180                               ? range_min
181                               : std::round(range_min / range_scale) *
182                                     range_scale) {}
183 
184   const float range_min;
185   const float range_scale;
186   const float range_min_rounded;
187 };
188 
189 // For use with QUANTIZE_WITH_EIGEN.
190 template <typename T>
191 struct FloatToQuantizedStruct {
192   static constexpr int number_of_bits = sizeof(T) * 8;
193   static constexpr int64 number_of_steps = static_cast<int64>(1)
194                                            << number_of_bits;
195   static constexpr double range_adjust =
196       (number_of_steps / (number_of_steps - 1.0));
197 
198   // Casting QInt32's lowest or highest to a float gives a float that can't be
199   // cast back to int32 or QInt32.  Instead, use bounds that can be converted
200   // back to int32 without going outside the range of an int32.
lower_bound_floatFloatToQuantizedStruct201   static float lower_bound_float() {
202     return Eigen::numext::maxi(
203         static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f);
204   }
upper_bound_floatFloatToQuantizedStruct205   static float upper_bound_float() {
206     return Eigen::numext::mini(
207         static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f);
208   }
209 
lowest_quantizedFloatToQuantizedStruct210   static float lowest_quantized() {
211     return static_cast<float>(Eigen::NumTraits<T>::lowest());
212   }
213 
FloatToQuantizedStructFloatToQuantizedStruct214   FloatToQuantizedStruct(float range_min, float range_max)
215       : range_min(range_min),
216         range_scale(range_max == range_min
217                         ? 0.0
218                         : (number_of_steps - 1.0) / (range_max - range_min)),
219         range_min_scaled(std::round(range_min * range_scale)) {}
220 
221   const float range_min;
222   const float range_scale;
223   const float range_min_scaled;
224 };
225 
226 template <class T1, class T2>
RequantizeInNewRange(T1 input,float min_input,float max_input,float min_new,float max_new)227 inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
228                                float min_new, float max_new) {
229   const float input_float = QuantizedToFloat<T1>(input, min_input, max_input);
230   return FloatToQuantized<T2>(input_float, min_new, max_new);
231 }
232 
233 template <class T1, class T2>
RequantizeManyInNewRange(const T1 * input,int64 count,float min_input,float max_input,float min_output,float max_output,T2 * output)234 inline void RequantizeManyInNewRange(const T1* input, int64 count,
235                                      float min_input, float max_input,
236                                      float min_output, float max_output,
237                                      T2* output) {
238   for (size_t index = 0; index < count; ++index) {
239     const float input_float =
240         QuantizedToFloat<T1>(input[index], min_input, max_input);
241     output[index] = FloatToQuantized<T2>(input_float, min_output, max_output);
242   }
243 }
244 
245 // Because converting 32-bit accumulated results down to eight bit is a common
246 // case, we have a specialized code path to handle it as efficiently as
247 // possible using only fixed-point math for the inner loop.
RequantizeManyInNewRangeReference(const qint32 * input,int64 count,float min_input,float max_input,float min_output,float max_output,quint8 * output)248 inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count,
249                                               float min_input, float max_input,
250                                               float min_output,
251                                               float max_output,
252                                               quint8* output) {
253   // Initially we calculate all the constants we need once, before we go into
254   // the inner loop.  If this is updated, also update the Eigen version.
255   const int fp_shift = 16;
256   const float input_range = max_input - min_input;
257   const float output_range = max_output - min_output;
258   const float recip_output_range =
259       output_range == 0.0 ? 0.0 : (255.0 / output_range);
260   const float input_rezero = (min_input + max_input) / 2.0;
261   const int64 range_scale_fp =
262       output_range == 0.0 ? 0.0
263                           : static_cast<int64>(255.0 * (1 << fp_shift) *
264                                                input_range / output_range);
265   const int64 input_offset_fp =
266       static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
267   const int64 output_offset_fp =
268       output_range == 0.0
269           ? 0
270           : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
271                                output_range);
272   const int64 rounding_delta = 1 << (fp_shift - 1);
273 
274   // Inside this loop we just do minimal adds, multiplies, and shifts, in a way
275   // that could be easily adapted for a SIMD implementation. It should also be
276   // possible to perform all the calculations in 32-bit rather than 64, but
277   // that's not been implemented yet.
278   for (tensorflow::int64 index = 0; index < count; ++index) {
279     const int64 input_value = static_cast<int64>(input[index]);
280     const int64 fp_value =
281         ((input_value * range_scale_fp) >> 32) + input_offset_fp;
282     const int64 offset_intermediate = fp_value - output_offset_fp;
283     const int64 round_intermediate = offset_intermediate + rounding_delta;
284     int64 quantized_int64 = round_intermediate >> fp_shift;
285     quantized_int64 = std::max(quantized_int64, int64{0});
286     quantized_int64 = std::min(quantized_int64, int64{255});
287     output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64));
288   }
289 }
290 
291 // Another common case is converting eight bit inputs up to thirty two bits, so
292 // we have specialized fixed-point code to accelerate that. There is also a NEON
293 // version for ARM devices below.
RequantizeManyInNewRange8To32BitReference(const quint8 * input,int64 count,float min_input,float max_input,float min_output,float max_output,qint32 * output)294 inline void RequantizeManyInNewRange8To32BitReference(
295     const quint8* input, int64 count, float min_input, float max_input,
296     float min_output, float max_output, qint32* output) {
297   const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
298   const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
299   const int64 code_0_int64 =
300       FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
301   const int64 code_1_int64 =
302       FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
303   const int32 mult_int32 = code_1_int64 - code_0_int64;
304   const int64 lowest_quantized =
305       static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
306   const int64 highest_quantized =
307       static_cast<int64>(Eigen::NumTraits<qint32>::highest());
308   for (int64 i = 0; i < count; ++i) {
309     const int64 input_value = static_cast<int64>(input[i]);
310     int64 output_value = code_0_int64 + (input_value * mult_int32);
311     output_value = std::max(output_value, lowest_quantized);
312     output_value = std::min(output_value, highest_quantized);
313     output[i] = static_cast<int32>(output_value);
314   }
315 }
316 
317 #ifdef QUANTIZATION_UTILS_USE_NEON
318 // Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
319 // intrinsics for ARM platforms.
RequantizeManyInNewRangeNeon(const qint32 * input,int64 count,float min_input,float max_input,float min_output,float max_output,quint8 * output)320 inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
321                                          float min_input, float max_input,
322                                          float min_output, float max_output,
323                                          quint8* output) {
324   // Initially we calculate all the constants we need once, before we go into
325   // the inner loop.  If this is updated, also update the Eigen version.
326   const int fp_shift = 16;
327 
328   // Calculate range variables in advance.
329   // Input range.
330   const float input_range = max_input - min_input;
331   // Output range.
332   const float output_range = max_output - min_output;
333   // Ratio of output range.
334   const float recip_output_range =
335       output_range == 0.0 ? 0.0 : (255.0 / output_range);
336   // Average of input range as zero position of input.
337   const float input_rezero = (min_input + max_input) / 2.0;
338   // In-out range scale.
339   const int32 range_scale_fp =
340       output_range == 0.0 ? 0.0
341                           : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
342                                                input_range / output_range);
343   // Input zero position offset to output.
344   const int32 input_offset_fp =
345       static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
346   // Output min offset.
347   const int32 output_offset_fp =
348       output_range == 0.0
349           ? 0
350           : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
351                                output_range);
352   const int32 rounding_delta = 1 << (fp_shift - 1);
353 
354   // broadcast range to each lane
355   const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
356   const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
357   const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
358   const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);
359 
360   int64 index = 0;
361   // Use SIMD to requantize.
362   for (; index < (count - 7); index += 8) {
363     const int32* input_ptr = &(input->value) + index;
364     const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
365     const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
366     const int32x4_t fp_value_low_32x4 = vaddq_s32(
367         input_offset_fp_32x4,
368         vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
369     const int32x4_t fp_value_high_32x4 = vaddq_s32(
370         input_offset_fp_32x4,
371         vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
372     const int32x4_t offset_intermediate_low_32x4 =
373         vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
374     const int32x4_t offset_intermediate_high_32x4 =
375         vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
376     const int32x4_t round_intermediate_low_32x4 =
377         vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
378     const int32x4_t round_intermediate_high_32x4 =
379         vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
380     const int16x4_t quantized_low_16x4 =
381         vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
382     const int16x4_t quantized_high_16x4 =
383         vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
384     const uint8x8_t quantized_8x8 =
385         vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
386     uint8* output_ptr = &(output->value) + index;
387     vst1_u8(output_ptr, quantized_8x8);
388   }
389 
390   // Requantize remaining elements in array without SIMD.
391   for (; index < count; ++index) {
392     const int32 input_value = static_cast<int32>(input[index]);
393     const int32 fp_value =
394         static_cast<int32>(
395             (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
396         input_offset_fp;
397     const int32 offset_intermediate = fp_value - output_offset_fp;
398     const int32 round_intermediate = offset_intermediate + rounding_delta;
399     int32 quantized_int32 = round_intermediate >> fp_shift;
400     quantized_int32 = std::max(quantized_int32, 0);
401     quantized_int32 = std::min(quantized_int32, 255);
402     output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
403   }
404 }
405 
406 template <>
407 inline void RequantizeManyInNewRange<qint32, quint8>(
408     const qint32* input, int64 count, float min_input, float max_input,
409     float min_output, float max_output, quint8* output) {
410   const float input_range = max_input - min_input;
411   const float output_range = max_output - min_output;
412   if ((input_range / output_range) > 16384.0f) {
413     // Our NEON implementation uses 32-bit math and can't handle very
414     // large ranges, so fall back to the reference implementation. We don't
415     // expect these to be common in models, so this shouldn't be a performance
416     // problem in practice.
417     RequantizeManyInNewRangeReference(input, count, min_input, max_input,
418                                       min_output, max_output, output);
419   } else {
420     RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output,
421                                  max_output, output);
422   }
423 }
424 
425 // NEON accelerated 16bit rounded division by 2^n.
426 template <int POW>
Divide16x8PowRound(const int16x8_t val)427 inline int16x8_t Divide16x8PowRound(const int16x8_t val) {
428   const int16x8_t val_sign = vshrq_n_s16(val, 15);
429   const int16x8_t val_xor = veorq_s16(val, val_sign);
430   const int16x8_t val_pos = vsubq_s16(val_xor, val_sign);
431   const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW);
432   const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign);
433   const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign);
434   return shifted_val;
435 }
436 
437 // NEON accelerated 64bit rounded division by 2^n.
438 template <int POW>
Divide64x2PowRound(const int64x2_t val)439 inline int64x2_t Divide64x2PowRound(const int64x2_t val) {
440   const int64x2_t val_sign = vshrq_n_s64(val, 63);
441   const int64x2_t val_xor = veorq_s64(val, val_sign);
442   const int64x2_t val_pos = vsubq_s64(val_xor, val_sign);
443   const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW);
444   const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign);
445   const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign);
446   return shifted_val;
447 }
448 
449 // NEON accelerated 16bit division by 2^n.
450 // CAVEAT: The input must be greater than min-int16 to avoid underflow.
451 template <int POW>
Divide16x8Pow(const int16x8_t val)452 inline int16x8_t Divide16x8Pow(const int16x8_t val) {
453   static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001;
454   static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL);
455   const int16x8_t val_sign = vshrq_n_s16(val, 15);
456   const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT);
457   const int16x8_t val_with_offset = vsubq_s16(val, neg_offset);
458   const int16x8_t shifted_wo_offset =
459       vsraq_n_s16(neg_offset, val_with_offset, POW);
460   return shifted_wo_offset;
461 }
462 
463 // NEON accelerated 64bit division by 2^n.
464 // CAVEAT: The input must be greater than min-int64 to avoid underflow.
465 template <int POW>
Divide64x2Pow(const int64x2_t val)466 inline int64x2_t Divide64x2Pow(const int64x2_t val) {
467   static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001;
468   static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL);
469   const int64x2_t val_sign = vshrq_n_s64(val, 63);
470   const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT);
471   const int64x2_t val_with_offset = vsubq_s64(val, neg_offset);
472   const int64x2_t shifted_wo_offset =
473       vsraq_n_s64(neg_offset, val_with_offset, POW);
474   return shifted_wo_offset;
475 }
476 
477 // 32bit x 2 NEON accelerated lerp computation.
478 template <int RESOLUTION>
ComputeLerp32x2(const int32x2_t top_left,const int32x2_t top_right,const int32x2_t bottom_left,const int32x2_t bottom_right,const int32x2_t x_lerp,const int32x2_t y_lerp)479 inline int32x2_t ComputeLerp32x2(const int32x2_t top_left,
480                                  const int32x2_t top_right,
481                                  const int32x2_t bottom_left,
482                                  const int32x2_t bottom_right,
483                                  const int32x2_t x_lerp,
484                                  const int32x2_t y_lerp) {
485   static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
486   constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION);
487   static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32);
488 
489   const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2);
490   const int64x2_t bottom_left_x_res =
491       vmull_s32(bottom_left, RESOLUTION_MULT32x2);
492 
493   const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left);
494   const int64x2_t top_x_res =
495       vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp);
496   const int32x2_t bottom_right_sub_bottom_left =
497       vsub_s32(bottom_right, bottom_left);
498   const int64x2_t bottom_x_res =
499       vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);
500 
501   const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res);
502   const int64x2_t bottom_sub_top =
503       Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res);
504   const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top);
505   const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res =
506       vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp);
507   const int64x2_t retval =
508       Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
509   const int32x2_t retval32 = vqmovn_s64(retval);
510   return retval32;
511 }
512 
513 // 8bit x 8 NEON accelerated lerp computation.
514 template <int RESOLUTION>
ComputeLerp8x8(const uint8x8_t top_left8x8,const uint8x8_t top_right8x8,const uint8x8_t bottom_left8x8,const uint8x8_t bottom_right8x8,const int16x8_t x_lerp,const int16x8_t y_lerp)515 inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8,
516                                 const uint8x8_t top_right8x8,
517                                 const uint8x8_t bottom_left8x8,
518                                 const uint8x8_t bottom_right8x8,
519                                 const int16x8_t x_lerp,
520                                 const int16x8_t y_lerp) {
521   static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
522   constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION);
523   static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL);
524 
525   const int16x8_t top_left_x_res =
526       vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT));
527   const int16x8_t bottom_left_x_res =
528       vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT));
529 
530   const int16x8_t top_right_sub_top_left =
531       vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8));
532   const int16x8_t top_x_res =
533       vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp);
534 
535   const int16x8_t bottom_right_sub_bottom_left =
536       vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8));
537   const int16x8_t bottom_x_res =
538       vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);
539 
540   const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res);
541   const int16x8_t bottom_sub_top =
542       Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res);
543   const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res =
544       vmlaq_s16(top_x_res, bottom_sub_top, y_lerp);
545   const int16x8_t retval16 =
546       Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
547   const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16));
548   return retval;
549 }
550 
551 // Requantize 8 x 8 quints to 8 x 32 qints in parallel by neon
552 // Return std::array instead of pointer to leverage return value optimization
Requantize8x8To32Neon(const uint8 * input_ptr,const int64x2_t input_0_64x2,const int32x2_t input_mult_32x2)553 inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
554     const uint8* input_ptr, const int64x2_t input_0_64x2,
555     const int32x2_t input_mult_32x2) {
556   const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
557   const int16x8_t input_value_16x8 =
558       vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
559   const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
560   const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
561   const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
562   const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
563   const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
564   const int32x2_t input_value_low_high_32x2 =
565       vget_high_s32(input_value_low_32x4);
566   const int32x2_t input_value_high_low_32x2 =
567       vget_low_s32(input_value_high_32x4);
568   const int32x2_t input_value_high_high_32x2 =
569       vget_high_s32(input_value_high_32x4);
570   const int64x2_t mult_result_low_low_64x2 =
571       vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
572   const int64x2_t mult_result_low_high_64x2 =
573       vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
574   const int64x2_t mult_result_high_low_64x2 =
575       vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
576   const int64x2_t mult_result_high_high_64x2 =
577       vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
578   const int32x2_t output_value_low_low_32x2 =
579       vqmovn_s64(mult_result_low_low_64x2);
580   const int32x2_t output_value_low_high_32x2 =
581       vqmovn_s64(mult_result_low_high_64x2);
582   const int32x2_t output_value_high_low_32x2 =
583       vqmovn_s64(mult_result_high_low_64x2);
584   const int32x2_t output_value_high_high_32x2 =
585       vqmovn_s64(mult_result_high_high_64x2);
586   const int32x4_t output_value_low_32x4 =
587       vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
588   const int32x4_t output_value_high_32x4 =
589       vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
590   return std::array<int32x4_t, 2>{
591       {output_value_low_32x4, output_value_high_32x4}};
592 }
593 
594 // Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
595 // intrinsics for ARM platforms.
596 template <>
597 inline void RequantizeManyInNewRange<quint8, qint32>(
598     const quint8* input, int64 count, float min_input, float max_input,
599     float min_output, float max_output, qint32* output) {
600   // Pre-calculate zero position and multiplier.
601   // Calculate 0 and 1 value in float.
602   const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
603   const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
604 
605   // Cast 0 and 1 value in int64.
606   const int64 code_0_int64 =
607       FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
608   const int64 code_1_int64 =
609       FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
610 
611   // Calculate multiplier.
612   const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);
613 
614   // Broadcast 0 position and multiplier to lanes
615   const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
616   const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);
617 
618   int64 i = 0;
619 
620   // Use SIMD to requantize array.
621   for (; i < (count - 7); i += 8) {
622     const uint8* input_ptr = &(input->value) + i;
623     int32* output_ptr = &(output->value) + i;
624     const std::array<int32x4_t, 2> output_value =
625         Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
626     vst1q_s32(output_ptr + 0, output_value[0]);
627     vst1q_s32(output_ptr + 4, output_value[1]);
628   }
629 
630   // Requantize remaining elements in array without SIMD.
631   const int64 lowest_quantized =
632       static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
633   const int64 highest_quantized =
634       static_cast<int64>(Eigen::NumTraits<qint32>::highest());
635 
636   for (; i < count; ++i) {
637     const int64 input_value = static_cast<int64>(input[i]);
638     int64 output_value = code_0_int64 + (input_value * mult_int32);
639     output_value = std::max(output_value, lowest_quantized);
640     output_value = std::min(output_value, highest_quantized);
641     output[i] = static_cast<int32>(output_value);
642   }
643 }
644 
645 #else
646 
// If SIMD implementations aren't available, then use these default reference
// versions.
//
// 32-bit -> 8-bit specialization for builds on this side of the preprocessor
// conditional: forwards every argument, unchanged and in the same order, to
// RequantizeManyInNewRangeReference().
template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                    min_output, max_output, output);
}
656 
// 8-bit -> 32-bit specialization for builds on this side of the preprocessor
// conditional: forwards every argument, unchanged and in the same order, to
// RequantizeManyInNewRange8To32BitReference().
template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
                                            min_output, max_output, output);
}
664 
665 #endif
666 
667 template <int shift>
668 struct int64_right_shift_op {
EIGEN_EMPTY_STRUCT_CTORint64_right_shift_op669   EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
670   EIGEN_DEVICE_FUNC
671   EIGEN_STRONG_INLINE const int64 operator()(const int64& a) const {
672     return a >> shift;
673   }
674 };
675 
676 // See RequantizeManyInNewRange() for a non-eigen reference implementation.
677 template <class T1, class T2>
RequantizeManyInNewRangeUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min_input,float max_input,float min_output,float max_output,Tensor * output)678 inline void RequantizeManyInNewRangeUsingEigen(
679     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
680     float max_input, float min_output, float max_output, Tensor* output) {
681   auto input_array = input.flat<T1>();
682   QuantizedToFloatStruct<T1> q2f(min_input, max_input);
683   auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
684   FloatToQuantizedStruct<T2> f2q(min_output, max_output);
685   auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);
686 
687   output->flat<T2>().device(device) = input_requantized;
688 }
689 
690 // See RequantizeManyInNewRange() for a non-eigen reference implementation.
691 //
692 // Because converting 32-bit accumulated results down to eight bit is a common
693 // case, we have a specialized code path to handle it as efficiently as
694 // possible using only fixed-point math for the inner loop.
695 template <>
696 inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
697     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
698     float max_input, float min_output, float max_output, Tensor* output) {
699   // Initially we calculate all the constants we need once, before we go into
700   // the inner loop.  If this is updated, also update the non-Eigen version.
701   const int fp_shift = 16;
702   const float input_range = max_input - min_input;
703   const float output_range = max_output - min_output;
704   const float recip_output_range =
705       output_range == 0.0 ? 0.0 : (255.0 / output_range);
706   const float input_rezero = (min_input + max_input) / 2.0;
707   const int64 range_scale_fp =
708       output_range == 0.0 ? 0.0
709                           : static_cast<int64>(255.0 * (1 << fp_shift) *
710                                                input_range / output_range);
711   const int64 input_offset_fp =
712       static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
713   const int64 output_offset_fp =
714       output_range == 0.0
715           ? 0
716           : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
717                                output_range);
718   const int64 rounding_delta = 1 << (fp_shift - 1);
719 
720   // Inside this eigen expression we just do minimal adds, multiplies, and
721   // shifts. It should be possible to perform all the calculations in 32-bit
722   // rather than 64, but that's not been implemented yet.
723   auto input_array = input.flat<qint32>();
724   auto fp_value = ((input_array.template cast<int64>() * range_scale_fp)
725                        .unaryExpr(int64_right_shift_op<32>())) +
726                   (input_offset_fp - output_offset_fp + rounding_delta);
727   auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
728   auto input_requantized = intermediate.cwiseMax(int64{0})
729                                .cwiseMin(int64{255})
730                                .template cast<int32>()
731                                .template cast<quint8>();
732   output->flat<quint8>().device(device) = input_requantized;
733 }
734 
735 // REQUIRES: 'result->NumElements() == input.NumElements()'
736 template <class T>
FloatTensorToQuantizedInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)737 void FloatTensorToQuantizedInPlaceUsingEigen(
738     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
739     float max, Tensor* result) {
740   DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
741   auto flat_input = input.flat<float>();
742   auto flat_result = result->flat<T>();
743   DCHECK_EQ(flat_input.size(), flat_result.size());
744 
745   FloatToQuantizedStruct<T> f2q(min, max);
746   flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
747 }
748 
749 template <class T>
FloatTensorToQuantizedInPlace(const Tensor & input,float min,float max,Tensor * result)750 void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
751                                    Tensor* result) {
752   DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
753   auto flat_input = input.flat<float>();
754   auto flat_result = result->flat<T>();
755   const int data_size = flat_input.size();
756   DCHECK(data_size == flat_result.size());
757   for (int i = 0; i < data_size; ++i) {
758     flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
759   }
760 }
761 
// Returns a newly constructed tensor with element type T and the same shape
// as `input`, filled by FloatTensorToQuantizedInPlace<T>() using the range
// [min, max].
template <class T>
Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
  Tensor result(DataTypeToEnum<T>::v(), input.shape());
  FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
  return result;
}
768 
769 // REQUIRES: 'result->NumElements() == input.NumElements()'
770 template <class T>
QuantizedTensorToFloatInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)771 void QuantizedTensorToFloatInPlaceUsingEigen(
772     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
773     float max, Tensor* result) {
774   DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
775   auto flat_input = input.flat<T>();
776   auto flat_result = result->flat<float>();
777   const int data_size = flat_input.size();
778   DCHECK(data_size == flat_result.size());
779 
780   QuantizedToFloatStruct<T> q2f(min, max);
781   flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
782 }
783 
784 // REQUIRES: 'result->NumElements() == input.NumElements()'
785 template <class T>
QuantizedTensorToFloatInPlace(const Tensor & input,float min,float max,Tensor * result)786 void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
787                                    Tensor* result) {
788   DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
789   auto flat_input = input.flat<T>();
790   auto flat_result = result->flat<float>();
791   const int data_size = flat_input.size();
792   DCHECK(data_size == flat_result.size());
793   for (int i = 0; i < data_size; ++i) {
794     flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
795   }
796 }
797 
// Returns a newly constructed DT_FLOAT tensor with the same shape as `input`,
// filled by QuantizedTensorToFloatInPlace<T>() using the range [min, max].
template <class T>
Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
  Tensor result(DT_FLOAT, input.shape());
  QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
  return result;
}
804 
// Declaration only; the definition is not in this part of the file.
// The QuantizedAdd functions below call this with both operands' ranges and
// afterwards read *output_min and *output_max as the range of the sum, so it
// presumably writes that range through the two pointers.
// NOTE(review): confirm against the definition.
void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max,
                                       float smaller_input_min,
                                       float smaller_input_max,
                                       float* output_min, float* output_max);
809 
810 // Add <input> and <smaller_input>.  If <smaller_input> has fewer elements than
811 // <input>, then it is broadcast onto <input>.
812 template <typename T1, typename T2, typename T3>
QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)813 void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
814                             const Tensor& input, float input_min,
815                             float input_max, const Tensor& smaller_input,
816                             float smaller_input_min, float smaller_input_max,
817                             Tensor* output, float* output_min,
818                             float* output_max) {
819   const auto& input_flat = input.flat<T1>();
820   const auto& smaller_input_flat = smaller_input.flat<T2>();
821   auto output_flat = output->flat<T3>();
822 
823   GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
824                                     smaller_input_max, output_min, output_max);
825   // To do addition properly, we need to compensate for a possibly unbalanced
826   // zero point in the total representation. The quantized value that
827   // represents the real number zero needs to be subtracted before addition to
828   // make sure that the identity of zero + zero = zero holds.
829   const T3 zero_in_total_space =
830       FloatToQuantized<T3>(0.0f, *output_min, *output_max);
831 
832   const int64 input_element_count = input.NumElements();
833   const int64 smaller_input_element_count = smaller_input.NumElements();
834 
835   QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
836   QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
837                                                smaller_input_max);
838   FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);
839 
840   auto smaller_input_float =
841       DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f);
842   auto smaller_input_in_total_space =
843       QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3);
844 
845   auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f);
846   auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3);
847 
848   Eigen::array<Eigen::DenseIndex, 1> bcast;
849   bcast[0] = input_element_count / smaller_input_element_count;
850   output_flat.device(device) =
851       input_in_total_space +
852       (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space);
853 }
854 
855 // This is a reference implementation of the bias addition for quantized
856 // buffers, designed to provide a clear specification for the result we
857 // want. We'll want to specialize this for particular hardware, and
858 // probably even fuse it with matrix multiplications in a lot of cases. It's
859 // important to show the clamping behavior we want in particular.
860 template <typename T1, typename T2, typename T3>
QuantizedAdd(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)861 void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
862                   float input_min, float input_max, const Tensor& smaller_input,
863                   float smaller_input_min, float smaller_input_max,
864                   Tensor* output, float* output_min, float* output_max) {
865   const auto& input_flat = input.flat<T1>();
866   const auto& smaller_input_flat = smaller_input.flat<T2>();
867   auto output_flat = output->flat<T3>();
868 
869   GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
870                                     smaller_input_max, output_min, output_max);
871   // To do addition properly, we need to compensate for a possibly unbalanced
872   // zero point in the total representation. The quantized value that
873   // represents the real number zero needs to be subtracted before addition to
874   // make sure that the identity of zero + zero = zero holds.
875   const T3 zero_in_total_space =
876       FloatToQuantized<T3>(0.0f, *output_min, *output_max);
877 
878   const int64 input_element_count = input.NumElements();
879   const int64 smaller_input_element_count = smaller_input.NumElements();
880 
881   float total_min = *output_min;
882   float total_max = *output_max;
883   const size_t how_many_iterations =
884       (input_element_count / smaller_input_element_count);
885   for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) {
886     const size_t offset = iteration * smaller_input_element_count;
887     for (int c = 0; c < smaller_input_element_count; ++c) {
888       const int index = (offset + c);
889       // The two numbers we're going to add can each be in very different
890       // ranges (e.g. the quantized value '127' may represent very different
891       // real numbers in both) so we need to convert them to a common range
892       // before we sum them.
893       const T1 input_value = input_flat(index);
894       const T3 input_in_total_space = RequantizeInNewRange<T1, T3>(
895           input_value, input_min, input_max, total_min, total_max);
896       const T2 smaller_input_value = smaller_input_flat(c);
897       const T3 smaller_input_in_total_space =
898           RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min,
899                                        smaller_input_max, total_min, total_max);
900       const T3 total_pre = input_in_total_space + smaller_input_in_total_space;
901       // As noted above, we need to compensate for the offset of the actual
902       // zero point in the space we're operating in.
903       const T3 total = total_pre + zero_in_total_space;
904       output_flat(index) = total;
905     }
906   }
907 }
908 
909 // See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute.
910 class TensorflowGemmlowpWorkersPool {
911  public:
TensorflowGemmlowpWorkersPool(thread::ThreadPool * workers)912   TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
913       : workers_(workers) {}
914 
~TensorflowGemmlowpWorkersPool()915   ~TensorflowGemmlowpWorkersPool() {
916     // This workaround ensures that all worker tasks have exited methods in the
917     // BlockingCounter. Without this, there is a race where the context is torn
918     // down while the counter is in use.
919     counter_to_decrement_when_ready_.Reset(0);
920   }
921 
Execute(const std::vector<gemmlowp::Task * > & tasks)922   void Execute(const std::vector<gemmlowp::Task*>& tasks) {
923     assert(!tasks.empty());
924     assert(workers_ != nullptr);
925     counter_to_decrement_when_ready_.Reset(tasks.size());
926     for (gemmlowp::Task* task : tasks) {
927       workers_->Schedule([this, task]() {
928         // TODO(cwhipkey): get a local_allocator from a thread local storage.
929         gemmlowp::Allocator local_allocator;
930         CHECK(task != nullptr);
931         task->local_allocator = &local_allocator;
932         task->Run();
933         counter_to_decrement_when_ready_.DecrementCount();
934       });
935     }
936     counter_to_decrement_when_ready_.Wait();
937     for (gemmlowp::Task* task : tasks) {
938       delete task;
939     }
940   }
941 
942  private:
943   thread::ThreadPool* const workers_;
944 
945   // The BlockingCounter used to wait for the workers.
946   gemmlowp::BlockingCounter counter_to_decrement_when_ready_;
947 
948   TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
949 };
950 
// gemmlowp multi-thread context that owns a TensorflowGemmlowpWorkersPool
// built on the given TensorFlow thread pool.
class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
 public:
  // Wraps `workers` in the owned workers pool and forwards `num_threads` to
  // set_max_num_threads().
  TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
      : workers_pool_(workers) {
    set_max_num_threads(num_threads);
  }

  // Returns a pointer to the workers pool owned by this context.
  TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }

 private:
  TensorflowGemmlowpWorkersPool workers_pool_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
};
965 
966 }  // namespace tensorflow
967 
968 #endif  // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
969