1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements a quantized version of the resize bilinear op.
17 
18 #define EIGEN_USE_THREADS
19 
20 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
21 #define USE_NEON
22 #define QUANTIZED_RESIZE_BILINEAR_USE_NEON
23 #include <arm_neon.h>
24 #endif
25 
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/kernels/image_resizer_state.h"
29 #include "tensorflow/core/kernels/quantization_utils.h"
30 #include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
// Compile-time switch for this translation unit; the code that reads it lies
// outside this section of the file.
// NOTE(review): presumably selects the float reference implementation
// (ResizeImageReference) when true -- confirm at the use site.
static constexpr bool USE_REFERENCE = false;
35 
36 namespace {
// Per-axis interpolation table, so the interpolation indices and weights are
// computed only once instead of once per output pixel.
// ComputeInterpolationWeights sizes every vector to out_size + 1 entries.
template <typename T_SCALE>
struct InterpolationCache {
  std::vector<int64> lower;  // Lower source index used in the interpolation
  std::vector<int64> upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  std::vector<float> lerp;
  // Fixed-point form of `lerp`: the same fractional weight multiplied by
  // (1 << resolution) and converted to T_SCALE.
  std::vector<T_SCALE> ilerp;
};
47 
48 template <typename T_SCALE, typename Scaler>
ComputeInterpolationWeights(const int64 out_size,const int64 in_size,const float scale,const int resolution,InterpolationCache<T_SCALE> * interpolation)49 inline void ComputeInterpolationWeights(
50     const int64 out_size, const int64 in_size, const float scale,
51     const int resolution, InterpolationCache<T_SCALE>* interpolation) {
52   const Scaler scaler;
53   interpolation->lower.resize(out_size + 1);
54   interpolation->upper.resize(out_size + 1);
55   interpolation->lerp.resize(out_size + 1);
56   interpolation->ilerp.resize(out_size + 1);
57 
58   interpolation->lower[out_size] = 0;
59   interpolation->upper[out_size] = 0;
60   for (int64 i = out_size - 1; i >= 0; --i) {
61     const float in = scaler(i, scale);
62     const float in_f = std::floor(in);
63     interpolation->lower[i] =
64         std::max(static_cast<int64>(in_f), static_cast<int64>(0));
65     interpolation->upper[i] =
66         std::min(static_cast<int64>(std::ceil(in)), in_size - 1);
67     interpolation->lerp[i] = in - in_f;
68     interpolation->ilerp[i] =
69         static_cast<T_SCALE>((in - in_f) * (1 << resolution));
70   }
71 }
72 
73 template <typename T_SCALE>
BuildLerpCache(const int64 out_size,const int64 in_size,const float scale,const int index_step,const int resolution,const bool half_pixel_centers)74 inline InterpolationCache<T_SCALE> BuildLerpCache(
75     const int64 out_size, const int64 in_size, const float scale,
76     const int index_step, const int resolution, const bool half_pixel_centers) {
77   InterpolationCache<T_SCALE> cache;
78   // Compute the cached interpolation weights on the x and y dimensions.
79   if (half_pixel_centers) {
80     ComputeInterpolationWeights<T_SCALE, HalfPixelScaler>(
81         out_size, in_size, scale, resolution, &cache);
82   } else {
83     ComputeInterpolationWeights<T_SCALE, LegacyScaler>(out_size, in_size, scale,
84                                                        resolution, &cache);
85   }
86   CHECK(index_step > 0);
87   if (index_step > 1) {
88     for (int i = 0; i < cache.lower.size(); ++i) {
89       cache.lower[i] *= index_step;
90       cache.upper[i] *= index_step;
91     }
92   }
93   return cache;
94 }
95 
96 /**
97  * Computes the bilinear interpolation from the appropriate 4 float points
98  * and the linear interpolation weights.
99  */
100 template <typename T>
ComputeLerpReference(const T in_top_left,const T in_top_right,const T in_bottom_left,const T in_bottom_right,const float x_lerp,const float y_lerp,const float min,const float max)101 inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
102                               const T in_bottom_left, const T in_bottom_right,
103                               const float x_lerp, const float y_lerp,
104                               const float min, const float max) {
105   const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
106   const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
107   const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
108   const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
109   const float top = top_left + (top_right - top_left) * x_lerp;
110   const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
111   const float out = top + (bottom - top) * y_lerp;
112   return FloatToQuantized<T>(out, min, max);
113 }
114 
// Returns (a - b) * c, converting each operand to T_CALC before it takes part
// in the arithmetic.
template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  // `auto` keeps whatever type the subtraction promotes to.
  const auto delta = static_cast<T_CALC>(a) - static_cast<T_CALC>(b);
  return delta * static_cast<T_CALC>(c);
}
120 
121 template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
ComputeLerp(const T top_left,const T top_right,const T bottom_left,const T bottom_right,const T_SCALE x_lerp,const T_SCALE y_lerp)122 inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
123                      const T bottom_right, const T_SCALE x_lerp,
124                      const T_SCALE y_lerp) {
125   constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
126   const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
127                      MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
128   const T_CALC bottom =
129       static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
130       MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
131   const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
132   return static_cast<T>(
133       static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
134 }
135 
136 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
ToUint8x8(const quint8 * v0,const quint8 * v1,const quint8 * v2,const quint8 * v3,const quint8 * v4,const quint8 * v5,const quint8 * v6,const quint8 * v7)137 inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
138                            const quint8* v3, const quint8* v4, const quint8* v5,
139                            const quint8* v6, const quint8* v7) {
140   static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
141   uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
142   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
143   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
144   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
145   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
146   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
147   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
148   ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
149   return ret;
150 }
151 
ToInt16x8(const int16 * v0,const int16 * v1,const int16 * v2,const int16 * v3,const int16 * v4,const int16 * v5,const int16 * v6,const int16 * v7)152 inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
153                            const int16* v3, const int16* v4, const int16* v5,
154                            const int16* v6, const int16* v7) {
155   static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
156   int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
157   ret = vld1q_lane_s16(v1, ret, 1);
158   ret = vld1q_lane_s16(v2, ret, 2);
159   ret = vld1q_lane_s16(v3, ret, 3);
160   ret = vld1q_lane_s16(v4, ret, 4);
161   ret = vld1q_lane_s16(v5, ret, 5);
162   ret = vld1q_lane_s16(v6, ret, 6);
163   ret = vld1q_lane_s16(v7, ret, 7);
164   return ret;
165 }
166 
ToInt32x2(const qint32 * v0,const qint32 * v1)167 inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
168   static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
169   const int32x2_t ret0 =
170       vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
171   const int32x2_t ret1 =
172       vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
173   return ret1;
174 }
175 
176 template <int RESOLUTION, bool X_LERP_SAME>
ComputeLerpx2(const qint32 * top_left0,const qint32 * top_right0,const qint32 * bottom_left0,const qint32 * bottom_right0,const qint32 * top_left1,const qint32 * top_right1,const qint32 * bottom_left1,const qint32 * bottom_right1,const int32 * x_lerp,const int32x2_t y_lerpsx)177 inline int32x2_t ComputeLerpx2(
178     const qint32* top_left0, const qint32* top_right0,
179     const qint32* bottom_left0, const qint32* bottom_right0,
180     const qint32* top_left1, const qint32* top_right1,
181     const qint32* bottom_left1, const qint32* bottom_right1,
182     const int32* x_lerp, const int32x2_t y_lerpsx) {
183   const int32x2_t x_lerpsx =
184       X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
185                   : vld1_s32(reinterpret_cast<const int32*>(x_lerp));
186 
187   const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
188   const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
189   const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
190   const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);
191 
192   const int32x2_t retval =
193       ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
194                                   bottom_rightsx, x_lerpsx, y_lerpsx);
195   return retval;
196 }
197 
198 template <int RESOLUTION>
ComputeLerpx8(const quint8 * tl0,const quint8 * tr0,const quint8 * bl0,const quint8 * br0,const int16 * xlp0,const quint8 * tl1,const quint8 * tr1,const quint8 * bl1,const quint8 * br1,const int16 * xlp1,const quint8 * tl2,const quint8 * tr2,const quint8 * bl2,const quint8 * br2,const int16 * xlp2,const quint8 * tl3,const quint8 * tr3,const quint8 * bl3,const quint8 * br3,const int16 * xlp3,const quint8 * tl4,const quint8 * tr4,const quint8 * bl4,const quint8 * br4,const int16 * xlp4,const quint8 * tl5,const quint8 * tr5,const quint8 * bl5,const quint8 * br5,const int16 * xlp5,const quint8 * tl6,const quint8 * tr6,const quint8 * bl6,const quint8 * br6,const int16 * xlp6,const quint8 * tl7,const quint8 * tr7,const quint8 * bl7,const quint8 * br7,const int16 * xlp7,const int16x8_t ys_lerpsx)199 inline uint8x8_t ComputeLerpx8(
200     const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
201     const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
202     const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
203     const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
204     const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
205     const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
206     const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
207     const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
208     const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
209     const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
210     const int16x8_t ys_lerpsx) {
211   const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
212   const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
213   const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
214   const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
215   const int16x8_t xs_lerpsx =
216       ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
217   return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
218                                     ys_lerpsx);
219 }
220 
221 // Expand address at compile time to improve performance
222 template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
223           int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
224           int CH6, int ID7, int CH7>
ComputeLerpx8Tmpl(const quint8 * const yl,const quint8 * yu,const int64 * xl,const int64 * xu,const int16 * xlp,const int16x8_t ys_lerpsx)225 inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
226                                    const int64* xl, const int64* xu,
227                                    const int16* xlp,
228                                    const int16x8_t ys_lerpsx) {
229   return ComputeLerpx8<RESOLUTION>(
230       yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
231       yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
232       yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
233       yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
234       yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
235       yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
236       yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
237       yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
238       yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
239       yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
240       yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
241 }
242 
243 #endif
244 
245 template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
OutputLerpForChannels(const InterpolationCache<T_SCALE> & xs,const int64 x,const T_SCALE ys_ilerp,const int channels,const float min,const float max,const T * ys_input_lower_ptr,const T * ys_input_upper_ptr,T * output_y_ptr)246 inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
247                                   const int64 x, const T_SCALE ys_ilerp,
248                                   const int channels, const float min,
249                                   const float max, const T* ys_input_lower_ptr,
250                                   const T* ys_input_upper_ptr,
251                                   T* output_y_ptr) {
252   const int64 xs_lower = xs.lower[x];
253   const int64 xs_upper = xs.upper[x];
254   const T_SCALE xs_ilerp = xs.ilerp[x];
255   for (int c = 0; c < channels; ++c) {
256     const T top_left = ys_input_lower_ptr[xs_lower + c];
257     const T top_right = ys_input_lower_ptr[xs_upper + c];
258     const T bottom_left = ys_input_upper_ptr[xs_lower + c];
259     const T bottom_right = ys_input_upper_ptr[xs_upper + c];
260     const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
261         top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
262     output_y_ptr[x * channels + c] = val;
263   }
264 }
265 
266 template <int RES>
OutputLerp8x8x1(const InterpolationCache<int16> & xs,const int64 x_start,const int16 ys_ilerp,const float min,const float max,const quint8 * const ys_input_lower_ptr,const quint8 * const ys_input_upper_ptr,quint8 * output_y_ptr)267 inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
268                             const int64 x_start, const int16 ys_ilerp,
269                             const float min, const float max,
270                             const quint8* const ys_input_lower_ptr,
271                             const quint8* const ys_input_upper_ptr,
272                             quint8* output_y_ptr) {
273 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
274   const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);
275 
276   const uint8x8_t x0x7 =
277       ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
278           ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
279           &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
280 
281   vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);
282 
283 #else
284   for (int x = x_start; x < x_start + 8; ++x) {
285     OutputLerpForChannels<RES, quint8, int16, int16>(
286         xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
287         output_y_ptr);
288   }
289 #endif
290 }
291 
292 template <int RES>
OutputLerp8x8x3(const InterpolationCache<int16> & xs,const int64 x_start,const int16 ys_ilerp,const float min,const float max,const quint8 * const ys_input_lower_ptr,const quint8 * const ys_input_upper_ptr,quint8 * output_y_ptr)293 inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
294                             const int64 x_start, const int16 ys_ilerp,
295                             const float min, const float max,
296                             const quint8* const ys_input_lower_ptr,
297                             const quint8* const ys_input_upper_ptr,
298                             quint8* output_y_ptr) {
299 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
300   const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);
301 
302   const uint8x8_t x0c0x2c1 =
303       ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
304           ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
305           &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
306 
307   vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);
308 
309   const uint8x8_t x2c2x5c0 =
310       ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
311           ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
312           &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
313 
314   vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);
315 
316   const uint8x8_t x5c1x7c2 =
317       ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
318           ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
319           &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
320 
321   vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
322           x5c1x7c2);
323 
324 #else
325   for (int x = x_start; x < x_start + 8; ++x) {
326     OutputLerpForChannels<RES, quint8, int16, int16>(
327         xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
328         output_y_ptr);
329   }
330 #endif
331 }
332 
333 template <int RESOLUTION>
OutputLerp32x4x1(const InterpolationCache<int32> & xs,const int64 x_start,const int32 ys_ilerp,const float min,const float max,const qint32 * const ys_input_lower_ptr,const qint32 * const ys_input_upper_ptr,qint32 * output_y_ptr)334 inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
335                              const int64 x_start, const int32 ys_ilerp,
336                              const float min, const float max,
337                              const qint32* const ys_input_lower_ptr,
338                              const qint32* const ys_input_upper_ptr,
339                              qint32* output_y_ptr) {
340 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
341   const int64 xs_lower0 = xs.lower[x_start];
342   const int64 xs_upper0 = xs.upper[x_start];
343   const int32* const xs_ilerp0 = &xs.ilerp[x_start];
344   const int64 xs_lower1 = xs.lower[x_start + 1];
345   const int64 xs_upper1 = xs.upper[x_start + 1];
346   const int64 xs_lower2 = xs.lower[x_start + 2];
347   const int64 xs_upper2 = xs.upper[x_start + 2];
348   const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
349   const int64 xs_lower3 = xs.lower[x_start + 3];
350   const int64 xs_upper3 = xs.upper[x_start + 3];
351 
352   const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);
353 
354   const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
355       ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
356       ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
357       ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
358       ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
359       y_lerpsx);
360 
361   const int32x2_t x1x2 = ComputeLerpx2<RESOLUTION, false>(
362       ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
363       ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
364       ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
365       ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
366       y_lerpsx);
367 
368   const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x1x2);
369 
370   vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);
371 
372 #else
373   for (int x = x_start; x < x_start + 4; ++x) {
374     OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
375         xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
376         output_y_ptr);
377   }
378 #endif
379 }
380 
381 template <int RESOLUTION>
OutputLerp32x4x3(const InterpolationCache<int32> & xs,const int64 x_start,const int32 ys_ilerp,const float min,const float max,const qint32 * const ys_input_lower_ptr,const qint32 * const ys_input_upper_ptr,qint32 * output_y_ptr)382 inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
383                              const int64 x_start, const int32 ys_ilerp,
384                              const float min, const float max,
385                              const qint32* const ys_input_lower_ptr,
386                              const qint32* const ys_input_upper_ptr,
387                              qint32* output_y_ptr) {
388 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
389   const int64 xs_lower0 = xs.lower[x_start];
390   const int64 xs_upper0 = xs.upper[x_start];
391   const int32* const xs_ilerp0 = &xs.ilerp[x_start];
392   const int64 xs_lower1 = xs.lower[x_start + 1];
393   const int64 xs_upper1 = xs.upper[x_start + 1];
394   const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
395   const int64 xs_lower2 = xs.lower[x_start + 2];
396   const int64 xs_upper2 = xs.upper[x_start + 2];
397   const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
398   const int64 xs_lower3 = xs.lower[x_start + 3];
399   const int64 xs_upper3 = xs.upper[x_start + 3];
400   const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];
401 
402   const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);
403 
404   const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
405       ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
406       ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
407       ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
408       ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
409       xs_ilerp0, y_lerpsx);
410 
411   const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
412       ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
413       ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
414       ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
415       ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
416       y_lerpsx);
417 
418   const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
419       ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
420       ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
421       ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
422       ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
423       xs_ilerp1, y_lerpsx);
424 
425   const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
426       ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
427       ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
428       ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
429       ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
430       xs_ilerp2, y_lerpsx);
431 
432   const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
433       ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
434       ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
435       ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
436       ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
437       y_lerpsx);
438 
439   const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
440       ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
441       ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
442       ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
443       ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
444       xs_ilerp3, y_lerpsx);
445 
446   const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
447   const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
448   const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);
449 
450   vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
451             x0c0x0c1x0c2x1c0);
452   vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
453             x1c1x1c2x2c0x2c1);
454   vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
455             x2c2x3c0x3c1x3c2);
456 
457 #else
458   for (int x = x_start; x < x_start + 4; ++x) {
459     OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
460         xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
461         output_y_ptr);
462   }
463 #endif
464 }
465 
466 template <typename T>
ResizeImageReference(typename TTypes<T,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<T,4>::Tensor * output)467 void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
468                           const int batch_size, const int64 in_height,
469                           const int64 in_width, const int64 out_height,
470                           const int64 out_width, const int channels,
471                           const float height_scale, const float width_scale,
472                           const float in_min, const float in_max,
473                           const bool half_pixel_centers,
474                           typename TTypes<T, 4>::Tensor* output) {
475   CHECK_NOTNULL(output);
476 
477   const InterpolationCache<float> xs = BuildLerpCache<float>(
478       out_width, in_width, width_scale, channels, 0, half_pixel_centers);
479   const InterpolationCache<float> ys = BuildLerpCache<float>(
480       out_height, in_height, height_scale, 1, 0, half_pixel_centers);
481 
482   const int64 in_row_size = in_width * channels;
483   const int64 in_batch_num_values = in_height * in_row_size;
484   const int64 out_row_size = out_width * channels;
485 
486   const T* input_b_ptr = images.data();
487 
488   T* output_y_ptr = output->data();
489   for (int b = 0; b < batch_size; ++b) {
490     for (int64 y = 0; y < out_height; ++y) {
491       const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
492       const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
493       const float ys_lerp = ys.lerp[y];
494       for (int64 x = 0; x < out_width; ++x) {
495         const int64 xs_lower = xs.lower[x];
496         const int64 xs_upper = xs.upper[x];
497         const float xs_lerp = xs.lerp[x];
498         for (int c = 0; c < channels; ++c) {
499           const T top_left = ys_input_lower_ptr[xs_lower + c];
500           const T top_right = ys_input_lower_ptr[xs_upper + c];
501           const T bottom_left = ys_input_upper_ptr[xs_lower + c];
502           const T bottom_right = ys_input_upper_ptr[xs_upper + c];
503           const T val = ComputeLerpReference<T>(
504               top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
505               in_min, in_max);
506           output_y_ptr[x * channels + c] = val;
507         }
508       }
509       output_y_ptr += out_row_size;
510     }
511     input_b_ptr += in_batch_num_values;
512   }
513 }
514 
// Primary template of the resize entry point: forwards every argument
// unchanged to ResizeImageReference<T>. Explicit specializations for qint32
// and quint8 are declared after it in this file.
template <typename T>
void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
                 const int batch_size, const int64 in_height,
                 const int64 in_width, const int64 out_height,
                 const int64 out_width, const int channels,
                 const float height_scale, const float width_scale,
                 const float in_min, const float in_max,
                 const bool half_pixel_centers,
                 typename TTypes<T, 4>::Tensor* output) {
  ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                          out_width, channels, height_scale, width_scale,
                          in_min, in_max, half_pixel_centers, output);
}
528 
529 template <>
ResizeImage(typename TTypes<qint32,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<qint32,4>::Tensor * output)530 void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
531                          const int batch_size, const int64 in_height,
532                          const int64 in_width, const int64 out_height,
533                          const int64 out_width, const int channels,
534                          const float height_scale, const float width_scale,
535                          const float in_min, const float in_max,
536                          const bool half_pixel_centers,
537                          typename TTypes<qint32, 4>::Tensor* output) {
538   // 30 is maximum resolution for signed int.
539   constexpr int RESOLUTION = 30;
540   constexpr int SIMD_STEP = 4;
541 
542   CHECK_NOTNULL(output);
543 
544   const InterpolationCache<int32> xs =
545       BuildLerpCache<int32>(out_width, in_width, width_scale, channels,
546                             RESOLUTION, half_pixel_centers);
547   const InterpolationCache<int32> ys = BuildLerpCache<int32>(
548       out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);
549 
550   const int64 in_row_size = in_width * channels;
551   const int64 in_batch_num_values = in_height * in_row_size;
552   const int64 out_row_size = out_width * channels;
553 
554   const qint32* input_b_ptr = images.data();
555 
556   qint32* output_y_ptr = output->data();
557 
558   for (int b = 0; b < batch_size; ++b) {
559     for (int64 y = 0; y < out_height; ++y) {
560       const qint32* ys_input_lower_ptr =
561           input_b_ptr + ys.lower[y] * in_row_size;
562       const qint32* ys_input_upper_ptr =
563           input_b_ptr + ys.upper[y] * in_row_size;
564       const int32 ys_ilerp = ys.ilerp[y];
565       // Optimized for channels == 1 or channels == 3 as this
566       // is typical channels.
567       int64 x = 0;
568       if (channels == 1) {
569         for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
570           OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
571                                        ys_input_lower_ptr, ys_input_upper_ptr,
572                                        output_y_ptr);
573         }
574       } else if (channels == 3) {
575         for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
576           OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
577                                        ys_input_lower_ptr, ys_input_upper_ptr,
578                                        output_y_ptr);
579         }
580       }
581       for (; x < out_width; ++x) {
582         OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
583             xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
584             ys_input_upper_ptr, output_y_ptr);
585       }
586       output_y_ptr += out_row_size;
587     }
588     input_b_ptr += in_batch_num_values;
589   }
590 }
591 
592 template <>
ResizeImage(typename TTypes<quint8,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<quint8,4>::Tensor * output)593 void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
594                          const int batch_size, const int64 in_height,
595                          const int64 in_width, const int64 out_height,
596                          const int64 out_width, const int channels,
597                          const float height_scale, const float width_scale,
598                          const float in_min, const float in_max,
599                          const bool half_pixel_centers,
600                          typename TTypes<quint8, 4>::Tensor* output) {
601   // 7 is maximum resolution for unsigned byte.
602   constexpr int RESOLUTION = 7;
603   constexpr int SIMD_STEP = 8;
604 
605   CHECK_NOTNULL(output);
606 
607   const InterpolationCache<int16> xs =
608       BuildLerpCache<int16>(out_width, in_width, width_scale, channels,
609                             RESOLUTION, half_pixel_centers);
610   const InterpolationCache<int16> ys = BuildLerpCache<int16>(
611       out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);
612 
613   const int64 in_row_size = in_width * channels;
614   const int64 in_batch_num_values = in_height * in_row_size;
615   const int64 out_row_size = out_width * channels;
616 
617   const quint8* input_b_ptr = images.data();
618 
619   quint8* output_y_ptr = output->data();
620 
621   for (int b = 0; b < batch_size; ++b) {
622     for (int64 y = 0; y < out_height; ++y) {
623       const quint8* ys_input_lower_ptr =
624           input_b_ptr + ys.lower[y] * in_row_size;
625       const quint8* ys_input_upper_ptr =
626           input_b_ptr + ys.upper[y] * in_row_size;
627       const int32 ys_ilerp = ys.ilerp[y];
628       // Optimized for channels == 1 or channels == 3 as this
629       // is typical channels.
630       // TODO(satok): Support more generic NEON optimized implementation
631       // for different channels.
632       int64 x = 0;
633       if (channels == 1) {
634         for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
635           OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
636                                       ys_input_lower_ptr, ys_input_upper_ptr,
637                                       output_y_ptr);
638         }
639       } else if (channels == 3) {
640         for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
641           OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
642                                       ys_input_lower_ptr, ys_input_upper_ptr,
643                                       output_y_ptr);
644         }
645       }
646       for (; x < out_width; ++x) {
647         OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
648             xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
649             ys_input_upper_ptr, output_y_ptr);
650       }
651       output_y_ptr += out_row_size;
652     }
653     input_b_ptr += in_batch_num_values;
654   }
655 }
656 
657 template <typename T>
ResizeBilinear(const typename TTypes<T,4>::ConstTensor & images,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<T,4>::Tensor * output)658 void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
659                     const float height_scale, const float width_scale,
660                     const float in_min, const float in_max,
661                     const bool half_pixel_centers,
662                     typename TTypes<T, 4>::Tensor* output) {
663   CHECK_NOTNULL(output);
664 
665   const int batch_size = images.dimension(0);
666   const int64 in_height = images.dimension(1);
667   const int64 in_width = images.dimension(2);
668   const int channels = images.dimension(3);
669 
670   const int64 out_height = output->dimension(1);
671   const int64 out_width = output->dimension(2);
672 
673   // Handle no-op resizes efficiently.
674   if (out_height == in_height && out_width == in_width) {
675     *output = images.template cast<T>();
676     return;
677   }
678 
679   if (USE_REFERENCE) {
680     ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
681                             out_width, channels, height_scale, width_scale,
682                             in_min, in_max, half_pixel_centers, output);
683   } else {
684     ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
685                    out_width, channels, height_scale, width_scale, in_min,
686                    in_max, half_pixel_centers, output);
687   }
688 }
689 
690 }  // namespace
691 
692 template <class T>
693 class QuantizedResizeBilinearOp : public OpKernel {
694  public:
QuantizedResizeBilinearOp(OpKernelConstruction * context)695   explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
696       : OpKernel(context) {
697     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
698     OP_REQUIRES_OK(
699         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
700   }
701 
Compute(OpKernelContext * context)702   void Compute(OpKernelContext* context) override {
703     const Tensor& input = context->input(0);
704     const float in_min = context->input(2).flat<float>()(0);
705     const float in_max = context->input(3).flat<float>()(0);
706 
707     ImageResizerState st(align_corners_, false);
708     st.ValidateAndCreateOutput(context, input);
709 
710     if (!context->status().ok()) return;
711 
712     // Return if the output is empty.
713     if (st.output->NumElements() == 0) return;
714 
715     typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
716     typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
717 
718     ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
719                       in_max, half_pixel_centers_, &output_data);
720     Tensor* out_min = nullptr;
721     OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
722     out_min->flat<float>()(0) = in_min;
723 
724     Tensor* out_max = nullptr;
725     OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
726     out_max->flat<float>()(0) = in_max;
727   }
728 
729  private:
730   bool align_corners_;
731   bool half_pixel_centers_;
732 
733   TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
734 };
735 
736 #define REGISTER_CPU_KERNEL(type)                         \
737   REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
738                               .Device(DEVICE_CPU)         \
739                               .HostMemory("size")         \
740                               .TypeConstraint<type>("T"), \
741                           QuantizedResizeBilinearOp<type>)
742 
743 REGISTER_CPU_KERNEL(::tensorflow::quint8);
744 REGISTER_CPU_KERNEL(::tensorflow::qint32);
745 REGISTER_CPU_KERNEL(float);
746 
747 }  // namespace tensorflow
748