1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implements a quantized version of the resize bilinear op.
17
18 #define EIGEN_USE_THREADS
19
20 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
21 #define USE_NEON
22 #define QUANTIZED_RESIZE_BILINEAR_USE_NEON
23 #include <arm_neon.h>
24 #endif
25
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/kernels/image_resizer_state.h"
29 #include "tensorflow/core/kernels/quantization_utils.h"
30 #include "tensorflow/core/platform/macros.h"
31
32 namespace tensorflow {
33
// Debugging switch read by ResizeBilinear: when true, every element type is
// resized with ResizeImageReference (dequantize -> float lerp -> requantize)
// instead of the fixed-point ResizeImage path.
constexpr bool USE_REFERENCE = false;
35
36 namespace {
37 // Compute the interpolation indices only once.
38 template <typename T_SCALE>
39 struct InterpolationCache {
40 std::vector<int64> lower; // Lower source index used in the interpolation
41 std::vector<int64> upper; // Upper source index used in the interpolation
42 // 1-D linear iterpolation scale (see:
43 // https://en.wikipedia.org/wiki/Bilinear_interpolation)
44 std::vector<float> lerp;
45 std::vector<T_SCALE> ilerp;
46 };
47
48 template <typename T_SCALE, typename Scaler>
ComputeInterpolationWeights(const int64 out_size,const int64 in_size,const float scale,const int resolution,InterpolationCache<T_SCALE> * interpolation)49 inline void ComputeInterpolationWeights(
50 const int64 out_size, const int64 in_size, const float scale,
51 const int resolution, InterpolationCache<T_SCALE>* interpolation) {
52 const Scaler scaler;
53 interpolation->lower.resize(out_size + 1);
54 interpolation->upper.resize(out_size + 1);
55 interpolation->lerp.resize(out_size + 1);
56 interpolation->ilerp.resize(out_size + 1);
57
58 interpolation->lower[out_size] = 0;
59 interpolation->upper[out_size] = 0;
60 for (int64 i = out_size - 1; i >= 0; --i) {
61 const float in = scaler(i, scale);
62 const float in_f = std::floor(in);
63 interpolation->lower[i] =
64 std::max(static_cast<int64>(in_f), static_cast<int64>(0));
65 interpolation->upper[i] =
66 std::min(static_cast<int64>(std::ceil(in)), in_size - 1);
67 interpolation->lerp[i] = in - in_f;
68 interpolation->ilerp[i] =
69 static_cast<T_SCALE>((in - in_f) * (1 << resolution));
70 }
71 }
72
73 template <typename T_SCALE>
BuildLerpCache(const int64 out_size,const int64 in_size,const float scale,const int index_step,const int resolution,const bool half_pixel_centers)74 inline InterpolationCache<T_SCALE> BuildLerpCache(
75 const int64 out_size, const int64 in_size, const float scale,
76 const int index_step, const int resolution, const bool half_pixel_centers) {
77 InterpolationCache<T_SCALE> cache;
78 // Compute the cached interpolation weights on the x and y dimensions.
79 if (half_pixel_centers) {
80 ComputeInterpolationWeights<T_SCALE, HalfPixelScaler>(
81 out_size, in_size, scale, resolution, &cache);
82 } else {
83 ComputeInterpolationWeights<T_SCALE, LegacyScaler>(out_size, in_size, scale,
84 resolution, &cache);
85 }
86 CHECK(index_step > 0);
87 if (index_step > 1) {
88 for (int i = 0; i < cache.lower.size(); ++i) {
89 cache.lower[i] *= index_step;
90 cache.upper[i] *= index_step;
91 }
92 }
93 return cache;
94 }
95
96 /**
97 * Computes the bilinear interpolation from the appropriate 4 float points
98 * and the linear interpolation weights.
99 */
100 template <typename T>
ComputeLerpReference(const T in_top_left,const T in_top_right,const T in_bottom_left,const T in_bottom_right,const float x_lerp,const float y_lerp,const float min,const float max)101 inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
102 const T in_bottom_left, const T in_bottom_right,
103 const float x_lerp, const float y_lerp,
104 const float min, const float max) {
105 const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
106 const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
107 const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
108 const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
109 const float top = top_left + (top_right - top_left) * x_lerp;
110 const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
111 const float out = top + (bottom - top) * y_lerp;
112 return FloatToQuantized<T>(out, min, max);
113 }
114
// Returns (a - b) * c, with every operand first converted to the wider
// calculation type T_CALC.
template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  const T_CALC lhs = static_cast<T_CALC>(a);
  const T_CALC rhs = static_cast<T_CALC>(b);
  return (lhs - rhs) * static_cast<T_CALC>(c);
}
120
121 template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
ComputeLerp(const T top_left,const T top_right,const T bottom_left,const T bottom_right,const T_SCALE x_lerp,const T_SCALE y_lerp)122 inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
123 const T bottom_right, const T_SCALE x_lerp,
124 const T_SCALE y_lerp) {
125 constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
126 const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
127 MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
128 const T_CALC bottom =
129 static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
130 MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
131 const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
132 return static_cast<T>(
133 static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
134 }
135
136 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
ToUint8x8(const quint8 * v0,const quint8 * v1,const quint8 * v2,const quint8 * v3,const quint8 * v4,const quint8 * v5,const quint8 * v6,const quint8 * v7)137 inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
138 const quint8* v3, const quint8* v4, const quint8* v5,
139 const quint8* v6, const quint8* v7) {
140 static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
141 uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
142 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
143 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
144 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
145 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
146 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
147 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
148 ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
149 return ret;
150 }
151
ToInt16x8(const int16 * v0,const int16 * v1,const int16 * v2,const int16 * v3,const int16 * v4,const int16 * v5,const int16 * v6,const int16 * v7)152 inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
153 const int16* v3, const int16* v4, const int16* v5,
154 const int16* v6, const int16* v7) {
155 static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
156 int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
157 ret = vld1q_lane_s16(v1, ret, 1);
158 ret = vld1q_lane_s16(v2, ret, 2);
159 ret = vld1q_lane_s16(v3, ret, 3);
160 ret = vld1q_lane_s16(v4, ret, 4);
161 ret = vld1q_lane_s16(v5, ret, 5);
162 ret = vld1q_lane_s16(v6, ret, 6);
163 ret = vld1q_lane_s16(v7, ret, 7);
164 return ret;
165 }
166
ToInt32x2(const qint32 * v0,const qint32 * v1)167 inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
168 static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
169 const int32x2_t ret0 =
170 vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
171 const int32x2_t ret1 =
172 vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
173 return ret1;
174 }
175
176 template <int RESOLUTION, bool X_LERP_SAME>
ComputeLerpx2(const qint32 * top_left0,const qint32 * top_right0,const qint32 * bottom_left0,const qint32 * bottom_right0,const qint32 * top_left1,const qint32 * top_right1,const qint32 * bottom_left1,const qint32 * bottom_right1,const int32 * x_lerp,const int32x2_t y_lerpsx)177 inline int32x2_t ComputeLerpx2(
178 const qint32* top_left0, const qint32* top_right0,
179 const qint32* bottom_left0, const qint32* bottom_right0,
180 const qint32* top_left1, const qint32* top_right1,
181 const qint32* bottom_left1, const qint32* bottom_right1,
182 const int32* x_lerp, const int32x2_t y_lerpsx) {
183 const int32x2_t x_lerpsx =
184 X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
185 : vld1_s32(reinterpret_cast<const int32*>(x_lerp));
186
187 const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
188 const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
189 const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
190 const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);
191
192 const int32x2_t retval =
193 ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
194 bottom_rightsx, x_lerpsx, y_lerpsx);
195 return retval;
196 }
197
198 template <int RESOLUTION>
ComputeLerpx8(const quint8 * tl0,const quint8 * tr0,const quint8 * bl0,const quint8 * br0,const int16 * xlp0,const quint8 * tl1,const quint8 * tr1,const quint8 * bl1,const quint8 * br1,const int16 * xlp1,const quint8 * tl2,const quint8 * tr2,const quint8 * bl2,const quint8 * br2,const int16 * xlp2,const quint8 * tl3,const quint8 * tr3,const quint8 * bl3,const quint8 * br3,const int16 * xlp3,const quint8 * tl4,const quint8 * tr4,const quint8 * bl4,const quint8 * br4,const int16 * xlp4,const quint8 * tl5,const quint8 * tr5,const quint8 * bl5,const quint8 * br5,const int16 * xlp5,const quint8 * tl6,const quint8 * tr6,const quint8 * bl6,const quint8 * br6,const int16 * xlp6,const quint8 * tl7,const quint8 * tr7,const quint8 * bl7,const quint8 * br7,const int16 * xlp7,const int16x8_t ys_lerpsx)199 inline uint8x8_t ComputeLerpx8(
200 const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
201 const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
202 const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
203 const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
204 const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
205 const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
206 const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
207 const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
208 const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
209 const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
210 const int16x8_t ys_lerpsx) {
211 const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
212 const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
213 const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
214 const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
215 const int16x8_t xs_lerpsx =
216 ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
217 return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
218 ys_lerpsx);
219 }
220
// Expand address at compile time to improve performance
//
// Computes eight quint8 outputs, one per lane. Lane k is described by the
// template pair (IDk, CHk): IDk is an offset into the x cache arrays (xl, xu,
// xlp) and CHk is the channel added to the sample address. For each lane the
// four samples are yl+xl[ID]+CH (top-left), yl+xu[ID]+CH (top-right),
// yu+xl[ID]+CH (bottom-left) and yu+xu[ID]+CH (bottom-right), with x weight
// xlp[ID].
//   yl / yu: lower / upper source row pointers for the current output row.
//   xl / xu: lower / upper x indices (as built by BuildLerpCache, presumably
//            already multiplied by the channel count — confirm at call site).
//   ys_lerpsx: y weight broadcast to all lanes.
// The argument order passed to ComputeLerpx8 must match its (tl, tr, bl, br,
// xlp) per-lane parameter order exactly.
template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
          int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
          int CH6, int ID7, int CH7>
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
                                   const int64* xl, const int64* xu,
                                   const int16* xlp,
                                   const int16x8_t ys_lerpsx) {
  return ComputeLerpx8<RESOLUTION>(
      yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
      yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
      yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
      yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
      yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
      yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
      yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
      yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
      yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
      yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
      yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
}
242
243 #endif
244
245 template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
OutputLerpForChannels(const InterpolationCache<T_SCALE> & xs,const int64 x,const T_SCALE ys_ilerp,const int channels,const float min,const float max,const T * ys_input_lower_ptr,const T * ys_input_upper_ptr,T * output_y_ptr)246 inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
247 const int64 x, const T_SCALE ys_ilerp,
248 const int channels, const float min,
249 const float max, const T* ys_input_lower_ptr,
250 const T* ys_input_upper_ptr,
251 T* output_y_ptr) {
252 const int64 xs_lower = xs.lower[x];
253 const int64 xs_upper = xs.upper[x];
254 const T_SCALE xs_ilerp = xs.ilerp[x];
255 for (int c = 0; c < channels; ++c) {
256 const T top_left = ys_input_lower_ptr[xs_lower + c];
257 const T top_right = ys_input_lower_ptr[xs_upper + c];
258 const T bottom_left = ys_input_upper_ptr[xs_lower + c];
259 const T bottom_right = ys_input_upper_ptr[xs_upper + c];
260 const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
261 top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
262 output_y_ptr[x * channels + c] = val;
263 }
264 }
265
266 template <int RES>
OutputLerp8x8x1(const InterpolationCache<int16> & xs,const int64 x_start,const int16 ys_ilerp,const float min,const float max,const quint8 * const ys_input_lower_ptr,const quint8 * const ys_input_upper_ptr,quint8 * output_y_ptr)267 inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
268 const int64 x_start, const int16 ys_ilerp,
269 const float min, const float max,
270 const quint8* const ys_input_lower_ptr,
271 const quint8* const ys_input_upper_ptr,
272 quint8* output_y_ptr) {
273 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
274 const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);
275
276 const uint8x8_t x0x7 =
277 ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
278 ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
279 &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
280
281 vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);
282
283 #else
284 for (int x = x_start; x < x_start + 8; ++x) {
285 OutputLerpForChannels<RES, quint8, int16, int16>(
286 xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
287 output_y_ptr);
288 }
289 #endif
290 }
291
292 template <int RES>
OutputLerp8x8x3(const InterpolationCache<int16> & xs,const int64 x_start,const int16 ys_ilerp,const float min,const float max,const quint8 * const ys_input_lower_ptr,const quint8 * const ys_input_upper_ptr,quint8 * output_y_ptr)293 inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
294 const int64 x_start, const int16 ys_ilerp,
295 const float min, const float max,
296 const quint8* const ys_input_lower_ptr,
297 const quint8* const ys_input_upper_ptr,
298 quint8* output_y_ptr) {
299 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
300 const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);
301
302 const uint8x8_t x0c0x2c1 =
303 ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
304 ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
305 &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
306
307 vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);
308
309 const uint8x8_t x2c2x5c0 =
310 ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
311 ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
312 &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
313
314 vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);
315
316 const uint8x8_t x5c1x7c2 =
317 ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
318 ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
319 &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);
320
321 vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
322 x5c1x7c2);
323
324 #else
325 for (int x = x_start; x < x_start + 8; ++x) {
326 OutputLerpForChannels<RES, quint8, int16, int16>(
327 xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
328 output_y_ptr);
329 }
330 #endif
331 }
332
333 template <int RESOLUTION>
OutputLerp32x4x1(const InterpolationCache<int32> & xs,const int64 x_start,const int32 ys_ilerp,const float min,const float max,const qint32 * const ys_input_lower_ptr,const qint32 * const ys_input_upper_ptr,qint32 * output_y_ptr)334 inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
335 const int64 x_start, const int32 ys_ilerp,
336 const float min, const float max,
337 const qint32* const ys_input_lower_ptr,
338 const qint32* const ys_input_upper_ptr,
339 qint32* output_y_ptr) {
340 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
341 const int64 xs_lower0 = xs.lower[x_start];
342 const int64 xs_upper0 = xs.upper[x_start];
343 const int32* const xs_ilerp0 = &xs.ilerp[x_start];
344 const int64 xs_lower1 = xs.lower[x_start + 1];
345 const int64 xs_upper1 = xs.upper[x_start + 1];
346 const int64 xs_lower2 = xs.lower[x_start + 2];
347 const int64 xs_upper2 = xs.upper[x_start + 2];
348 const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
349 const int64 xs_lower3 = xs.lower[x_start + 3];
350 const int64 xs_upper3 = xs.upper[x_start + 3];
351
352 const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);
353
354 const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
355 ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
356 ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
357 ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
358 ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
359 y_lerpsx);
360
361 const int32x2_t x1x2 = ComputeLerpx2<RESOLUTION, false>(
362 ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
363 ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
364 ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
365 ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
366 y_lerpsx);
367
368 const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x1x2);
369
370 vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);
371
372 #else
373 for (int x = x_start; x < x_start + 4; ++x) {
374 OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
375 xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
376 output_y_ptr);
377 }
378 #endif
379 }
380
381 template <int RESOLUTION>
OutputLerp32x4x3(const InterpolationCache<int32> & xs,const int64 x_start,const int32 ys_ilerp,const float min,const float max,const qint32 * const ys_input_lower_ptr,const qint32 * const ys_input_upper_ptr,qint32 * output_y_ptr)382 inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
383 const int64 x_start, const int32 ys_ilerp,
384 const float min, const float max,
385 const qint32* const ys_input_lower_ptr,
386 const qint32* const ys_input_upper_ptr,
387 qint32* output_y_ptr) {
388 #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
389 const int64 xs_lower0 = xs.lower[x_start];
390 const int64 xs_upper0 = xs.upper[x_start];
391 const int32* const xs_ilerp0 = &xs.ilerp[x_start];
392 const int64 xs_lower1 = xs.lower[x_start + 1];
393 const int64 xs_upper1 = xs.upper[x_start + 1];
394 const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
395 const int64 xs_lower2 = xs.lower[x_start + 2];
396 const int64 xs_upper2 = xs.upper[x_start + 2];
397 const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
398 const int64 xs_lower3 = xs.lower[x_start + 3];
399 const int64 xs_upper3 = xs.upper[x_start + 3];
400 const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];
401
402 const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);
403
404 const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
405 ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
406 ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
407 ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
408 ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
409 xs_ilerp0, y_lerpsx);
410
411 const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
412 ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
413 ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
414 ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
415 ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
416 y_lerpsx);
417
418 const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
419 ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
420 ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
421 ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
422 ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
423 xs_ilerp1, y_lerpsx);
424
425 const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
426 ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
427 ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
428 ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
429 ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
430 xs_ilerp2, y_lerpsx);
431
432 const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
433 ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
434 ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
435 ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
436 ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
437 y_lerpsx);
438
439 const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
440 ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
441 ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
442 ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
443 ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
444 xs_ilerp3, y_lerpsx);
445
446 const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
447 const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
448 const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);
449
450 vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
451 x0c0x0c1x0c2x1c0);
452 vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
453 x1c1x1c2x2c0x2c1);
454 vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
455 x2c2x3c0x3c1x3c2);
456
457 #else
458 for (int x = x_start; x < x_start + 4; ++x) {
459 OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
460 xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
461 output_y_ptr);
462 }
463 #endif
464 }
465
466 template <typename T>
ResizeImageReference(typename TTypes<T,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<T,4>::Tensor * output)467 void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
468 const int batch_size, const int64 in_height,
469 const int64 in_width, const int64 out_height,
470 const int64 out_width, const int channels,
471 const float height_scale, const float width_scale,
472 const float in_min, const float in_max,
473 const bool half_pixel_centers,
474 typename TTypes<T, 4>::Tensor* output) {
475 CHECK_NOTNULL(output);
476
477 const InterpolationCache<float> xs = BuildLerpCache<float>(
478 out_width, in_width, width_scale, channels, 0, half_pixel_centers);
479 const InterpolationCache<float> ys = BuildLerpCache<float>(
480 out_height, in_height, height_scale, 1, 0, half_pixel_centers);
481
482 const int64 in_row_size = in_width * channels;
483 const int64 in_batch_num_values = in_height * in_row_size;
484 const int64 out_row_size = out_width * channels;
485
486 const T* input_b_ptr = images.data();
487
488 T* output_y_ptr = output->data();
489 for (int b = 0; b < batch_size; ++b) {
490 for (int64 y = 0; y < out_height; ++y) {
491 const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
492 const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
493 const float ys_lerp = ys.lerp[y];
494 for (int64 x = 0; x < out_width; ++x) {
495 const int64 xs_lower = xs.lower[x];
496 const int64 xs_upper = xs.upper[x];
497 const float xs_lerp = xs.lerp[x];
498 for (int c = 0; c < channels; ++c) {
499 const T top_left = ys_input_lower_ptr[xs_lower + c];
500 const T top_right = ys_input_lower_ptr[xs_upper + c];
501 const T bottom_left = ys_input_upper_ptr[xs_lower + c];
502 const T bottom_right = ys_input_upper_ptr[xs_upper + c];
503 const T val = ComputeLerpReference<T>(
504 top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
505 in_min, in_max);
506 output_y_ptr[x * channels + c] = val;
507 }
508 }
509 output_y_ptr += out_row_size;
510 }
511 input_b_ptr += in_batch_num_values;
512 }
513 }
514
515 template <typename T>
ResizeImage(typename TTypes<T,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<T,4>::Tensor * output)516 void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
517 const int batch_size, const int64 in_height,
518 const int64 in_width, const int64 out_height,
519 const int64 out_width, const int channels,
520 const float height_scale, const float width_scale,
521 const float in_min, const float in_max,
522 const bool half_pixel_centers,
523 typename TTypes<T, 4>::Tensor* output) {
524 ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
525 out_width, channels, height_scale, width_scale,
526 in_min, in_max, half_pixel_centers, output);
527 }
528
529 template <>
ResizeImage(typename TTypes<qint32,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<qint32,4>::Tensor * output)530 void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
531 const int batch_size, const int64 in_height,
532 const int64 in_width, const int64 out_height,
533 const int64 out_width, const int channels,
534 const float height_scale, const float width_scale,
535 const float in_min, const float in_max,
536 const bool half_pixel_centers,
537 typename TTypes<qint32, 4>::Tensor* output) {
538 // 30 is maximum resolution for signed int.
539 constexpr int RESOLUTION = 30;
540 constexpr int SIMD_STEP = 4;
541
542 CHECK_NOTNULL(output);
543
544 const InterpolationCache<int32> xs =
545 BuildLerpCache<int32>(out_width, in_width, width_scale, channels,
546 RESOLUTION, half_pixel_centers);
547 const InterpolationCache<int32> ys = BuildLerpCache<int32>(
548 out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);
549
550 const int64 in_row_size = in_width * channels;
551 const int64 in_batch_num_values = in_height * in_row_size;
552 const int64 out_row_size = out_width * channels;
553
554 const qint32* input_b_ptr = images.data();
555
556 qint32* output_y_ptr = output->data();
557
558 for (int b = 0; b < batch_size; ++b) {
559 for (int64 y = 0; y < out_height; ++y) {
560 const qint32* ys_input_lower_ptr =
561 input_b_ptr + ys.lower[y] * in_row_size;
562 const qint32* ys_input_upper_ptr =
563 input_b_ptr + ys.upper[y] * in_row_size;
564 const int32 ys_ilerp = ys.ilerp[y];
565 // Optimized for channels == 1 or channels == 3 as this
566 // is typical channels.
567 int64 x = 0;
568 if (channels == 1) {
569 for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
570 OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
571 ys_input_lower_ptr, ys_input_upper_ptr,
572 output_y_ptr);
573 }
574 } else if (channels == 3) {
575 for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
576 OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
577 ys_input_lower_ptr, ys_input_upper_ptr,
578 output_y_ptr);
579 }
580 }
581 for (; x < out_width; ++x) {
582 OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
583 xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
584 ys_input_upper_ptr, output_y_ptr);
585 }
586 output_y_ptr += out_row_size;
587 }
588 input_b_ptr += in_batch_num_values;
589 }
590 }
591
592 template <>
ResizeImage(typename TTypes<quint8,4>::ConstTensor images,const int batch_size,const int64 in_height,const int64 in_width,const int64 out_height,const int64 out_width,const int channels,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<quint8,4>::Tensor * output)593 void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
594 const int batch_size, const int64 in_height,
595 const int64 in_width, const int64 out_height,
596 const int64 out_width, const int channels,
597 const float height_scale, const float width_scale,
598 const float in_min, const float in_max,
599 const bool half_pixel_centers,
600 typename TTypes<quint8, 4>::Tensor* output) {
601 // 7 is maximum resolution for unsigned byte.
602 constexpr int RESOLUTION = 7;
603 constexpr int SIMD_STEP = 8;
604
605 CHECK_NOTNULL(output);
606
607 const InterpolationCache<int16> xs =
608 BuildLerpCache<int16>(out_width, in_width, width_scale, channels,
609 RESOLUTION, half_pixel_centers);
610 const InterpolationCache<int16> ys = BuildLerpCache<int16>(
611 out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);
612
613 const int64 in_row_size = in_width * channels;
614 const int64 in_batch_num_values = in_height * in_row_size;
615 const int64 out_row_size = out_width * channels;
616
617 const quint8* input_b_ptr = images.data();
618
619 quint8* output_y_ptr = output->data();
620
621 for (int b = 0; b < batch_size; ++b) {
622 for (int64 y = 0; y < out_height; ++y) {
623 const quint8* ys_input_lower_ptr =
624 input_b_ptr + ys.lower[y] * in_row_size;
625 const quint8* ys_input_upper_ptr =
626 input_b_ptr + ys.upper[y] * in_row_size;
627 const int32 ys_ilerp = ys.ilerp[y];
628 // Optimized for channels == 1 or channels == 3 as this
629 // is typical channels.
630 // TODO(satok): Support more generic NEON optimized implementation
631 // for different channels.
632 int64 x = 0;
633 if (channels == 1) {
634 for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
635 OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
636 ys_input_lower_ptr, ys_input_upper_ptr,
637 output_y_ptr);
638 }
639 } else if (channels == 3) {
640 for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
641 OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
642 ys_input_lower_ptr, ys_input_upper_ptr,
643 output_y_ptr);
644 }
645 }
646 for (; x < out_width; ++x) {
647 OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
648 xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
649 ys_input_upper_ptr, output_y_ptr);
650 }
651 output_y_ptr += out_row_size;
652 }
653 input_b_ptr += in_batch_num_values;
654 }
655 }
656
657 template <typename T>
ResizeBilinear(const typename TTypes<T,4>::ConstTensor & images,const float height_scale,const float width_scale,const float in_min,const float in_max,const bool half_pixel_centers,typename TTypes<T,4>::Tensor * output)658 void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
659 const float height_scale, const float width_scale,
660 const float in_min, const float in_max,
661 const bool half_pixel_centers,
662 typename TTypes<T, 4>::Tensor* output) {
663 CHECK_NOTNULL(output);
664
665 const int batch_size = images.dimension(0);
666 const int64 in_height = images.dimension(1);
667 const int64 in_width = images.dimension(2);
668 const int channels = images.dimension(3);
669
670 const int64 out_height = output->dimension(1);
671 const int64 out_width = output->dimension(2);
672
673 // Handle no-op resizes efficiently.
674 if (out_height == in_height && out_width == in_width) {
675 *output = images.template cast<T>();
676 return;
677 }
678
679 if (USE_REFERENCE) {
680 ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
681 out_width, channels, height_scale, width_scale,
682 in_min, in_max, half_pixel_centers, output);
683 } else {
684 ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
685 out_width, channels, height_scale, width_scale, in_min,
686 in_max, half_pixel_centers, output);
687 }
688 }
689
690 } // namespace
691
692 template <class T>
693 class QuantizedResizeBilinearOp : public OpKernel {
694 public:
QuantizedResizeBilinearOp(OpKernelConstruction * context)695 explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
696 : OpKernel(context) {
697 OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
698 OP_REQUIRES_OK(
699 context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
700 }
701
Compute(OpKernelContext * context)702 void Compute(OpKernelContext* context) override {
703 const Tensor& input = context->input(0);
704 const float in_min = context->input(2).flat<float>()(0);
705 const float in_max = context->input(3).flat<float>()(0);
706
707 ImageResizerState st(align_corners_, false);
708 st.ValidateAndCreateOutput(context, input);
709
710 if (!context->status().ok()) return;
711
712 // Return if the output is empty.
713 if (st.output->NumElements() == 0) return;
714
715 typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
716 typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
717
718 ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
719 in_max, half_pixel_centers_, &output_data);
720 Tensor* out_min = nullptr;
721 OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
722 out_min->flat<float>()(0) = in_min;
723
724 Tensor* out_max = nullptr;
725 OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
726 out_max->flat<float>()(0) = in_max;
727 }
728
729 private:
730 bool align_corners_;
731 bool half_pixel_centers_;
732
733 TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
734 };
735
736 #define REGISTER_CPU_KERNEL(type) \
737 REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
738 .Device(DEVICE_CPU) \
739 .HostMemory("size") \
740 .TypeConstraint<type>("T"), \
741 QuantizedResizeBilinearOp<type>)
742
743 REGISTER_CPU_KERNEL(::tensorflow::quint8);
744 REGISTER_CPU_KERNEL(::tensorflow::qint32);
745 REGISTER_CPU_KERNEL(float);
746
747 } // namespace tensorflow
748