/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

#include <algorithm>
#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/kernels/quantization_utils.h"

#ifdef USE_NEON
namespace {

// Single pass mean and variance.
// Shape of `input` is [rows x cols], shape of both `mean` and `variance`
// is [cols].
// Note: `mean` and `variance` are computed over the raw uint8 values and are
// not scaled by (input_max - input_min) / 255; the scale factor cancels in
// the normalization (x - mean) / sqrt(variance).
// The following is a straightforward implementation of the parallel algorithm
// described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
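// For reference, the scalar combine step used below is, per column:
//   n_X    = n_A + n_B
//   delta  = mean_B - mean_A
//   mean_X = (n_A * mean_A + n_B * mean_B) / n_X
//   M2_X   = M2_A + M2_B + delta^2 * n_A * n_B / n_X
// where M2 is the sum of squared deviations from the mean (so variance =
// M2 / n). The vector code applies this update once per 256-row block, to 16
// columns at a time in four float32x4 registers.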
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // The implementation operates on 16 columns at a time.
  // Assumes cols % 16 == 0.
  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Vector registers to track the running sum across the rows. Since there
    // are 16 columns, we have four 32x4 registers.
    uint32x4_t sum[4] = {0};

    float nA = 0.0f;
    // Running average and the second moment.
    float32x4_t xA[4] = {0.0f};
    float32x4_t M2A[4] = {0.0f};

    const uint8_t* inp_ptr = input + col_offset;
    // Go over the rows in chunks of 256 so that each block's integer sums
    // stay small: a block's sum of squares is at most 255^2 * 256, well below
    // 2^32, before it is folded into the float running mean and moment.
    for (uint32_t row = 0; row < rows; row += 256) {
      // Running sum and sum of squares for the 256 rows.
      uint32x4_t sub_sum[4] = {0};
      uint32x4_t sub_sq_sum[4] = {0};
      const uint32_t limit = std::min(rows, row + 256);
      const float nB = limit - row;
      for (uint32_t subrow = row; subrow < limit; ++subrow) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

        const uint8x8_t v_high = vget_high_u8(v);
        const uint8x8_t v_low = vget_low_u8(v);

        const uint16x8_t v_high_u16 = vmovl_u8(v_high);
        const uint16x8_t v_low_u16 = vmovl_u8(v_low);

        const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
        const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
        const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
        const uint16x4_t v_low_low = vget_low_u16(v_low_u16);

        sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
        sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
        sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
        sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);

        sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
        sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
        sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
        sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
      }

      // Update the full running sum and second moment from this block's
      // partial sums.
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);
        const float nX = nA + nB;
        // xB is the average of the block's up to 256 elements.
        const float32x4_t xB =
            vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB);

        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);

        // M2B = sum(x^2) - (sum(x))^2 / nB over the block's raw values x.
        const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
                                          vmulq_n_f32(sub_sum_sq, 1.0f / nB));
        const float32x4_t last_term =
            vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
        // M2A = old M2A + M2B + delta^2 * nA*nB/nX
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
      }
      // nA is the number of rows folded in so far, which after this block is
      // exactly `limit`; `nA += limit` would double-count earlier blocks.
      nA = limit;
    }

    // Write the final mean and variance for the 16 columns. Note the reversed
    // indexing: sub_sum[0]/M2A[0] accumulated the highest lanes (columns
    // 12-15), so index 3 corresponds to columns 0-3.
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}

// Compute min and max of (input - mean) / sqrt(variance + epsilon).
// This is done in a separate pass so that the normalized values can be
// computed temporarily in floating point precision without having to be
// stored anywhere.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  // lowest(), not min(): min() is the smallest positive float, which would be
  // a wrong (too high) starting point for a running maximum.
  float v_maximum = std::numeric_limits<float>::lowest();
  float v_minimum = std::numeric_limits<float>::max();
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Loaded highest lanes first so that mean[i] / variance[i] pair with
    // v_float[i] below (v_float[0] holds columns 12-15), matching the
    // convention used in InstanceNorm().
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
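    // vrsqrteq_f32 is only a low-precision reciprocal square root estimate.
    // If more accuracy were needed, one Newton-Raphson step would refine it:
    //   est = vmulq_f32(est, vrsqrtsq_f32(vmulq_f32(x, est), est));
    // The same estimate is used in InstanceNorm() below, so both passes see
    // identical normalized values.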

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        // Reduce the four lanes to a scalar with pairwise max/min.
        const float32x2_t high = vget_high_f32(normed);
        const float32x2_t low = vget_low_f32(normed);
        float32x2_t tmp_max = vpmax_f32(low, high);
        tmp_max = vpmax_f32(tmp_max, tmp_max);
        v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
        float32x2_t tmp_min = vpmin_f32(low, high);
        tmp_min = vpmin_f32(tmp_min, tmp_min);
        v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
      }
    }
  }
  *minimum = v_minimum;
  *maximum = v_maximum;
}

// Compute (input - mean) / sqrt(variance + epsilon) in floating point,
// quantize it into the range [minimum, maximum] and store the result as
// quint8.
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  const float32x4_t out_min = vdupq_n_f32(minimum);
  const float out_scale = 255.0f / (maximum - minimum);
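  // Each normalized value x maps to (x - minimum) * out_scale, truncated
  // toward zero by vcvtq_s32_f32 below and saturated to [0, 255] by the
  // vqmovun/vqmovn narrowing; adding 0.5f before the convert would give
  // round-to-nearest instead.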

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
    const uint8_t* inp_ptr = input + col_offset;
    uint8_t* out_ptr = output + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      uint16x4_t normed_uint16[4];
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const int32x4_t normed_int32 =
            vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale));
        normed_uint16[i] = vqmovun_s32(normed_int32);
      }
      vst1_u8(out_ptr,
              vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2])));
      vst1_u8(out_ptr + 8,
              vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0])));
      out_ptr += cols;
    }
  }
}
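
// Together, ColMeanAndVariance, MinAndMax and InstanceNorm form a three-pass
// pipeline over a [rows x cols] matrix of quantized activations: (1) per-column
// mean and variance, (2) global min/max of the normalized values (skipped when
// the caller supplies an output range), and (3) normalize and requantize.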

}  // end namespace
#endif  // USE_NEON

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class QuantizedInstanceNorm : public OpKernel {
 public:
  explicit QuantizedInstanceNorm(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_separation", &min_separation_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_range_given", &output_range_given_));
    if (output_range_given_) {
      OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
      OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
      OP_REQUIRES(context, given_y_min_ < given_y_max_,
                  errors::InvalidArgument(
                      "given_y_min must be less than given_y_max : ",
                      given_y_min_, " >= ", given_y_max_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    float input_min = context->input(1).flat<float>()(0);
    float input_max = context->input(2).flat<float>()(0);
    // Dequantization convention: float_value = input_min + q * input_scale.
    float input_scale = (input_max - input_min) / 255.0f;

    OP_REQUIRES(context, input_min < input_max,
                errors::InvalidArgument(
                    "input_min must be less than input_max : ", input_min,
                    " >= ", input_max));
    auto input_tensor = input.tensor<quint8, 4>();
    auto N = input_tensor.dimension(0);
    auto H = input_tensor.dimension(1);
    auto W = input_tensor.dimension(2);
    auto C = input_tensor.dimension(3);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));

    typedef TTypes<float>::Tensor::Index Index;

#if defined(EIGEN_HAS_INDEX_LIST)
    const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
        reduction_indices;
    Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
        broadcast_spec;
    broadcast_spec.set(1, H);
    broadcast_spec.set(2, W);
    Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
        expand_spec;
    expand_spec.set(0, N);
    expand_spec.set(3, C);
#else
    const Eigen::array<Index, 2> reduction_indices{1, 2};
    const Eigen::array<Index, 4> broadcast_spec{1, H, W, 1};
    const Eigen::array<Index, 4> expand_spec{N, 1, 1, C};
#endif
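    // Example (hypothetical shapes): for an input of shape [1, 224, 224, 32],
    // the reduction over axes (1, 2) yields float_mean / float_variance of
    // shape [1, 32]; reshape(expand_spec) makes them [1, 1, 1, 32] and
    // broadcast(broadcast_spec) expands them back to [1, 224, 224, 32].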

    Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
    Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);

#ifdef USE_NEON
    // The NEON path handles a single instance whose channel count is a
    // multiple of 16 (the layout ColMeanAndVariance assumes), treating the
    // input as an [H * W, C] matrix.
    if (N == 1 && (C % 16 == 0)) {
      VLOG(2) << "Calling optimized";
      ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                         H * W, C, float_mean.data(), float_variance.data());

      float minimum = given_y_min_, maximum = given_y_max_;
      if (!output_range_given_) {
        MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                  C, float_mean.data(), float_variance.data(),
                  variance_epsilon_, &minimum, &maximum);
      }

      if (maximum - minimum < min_separation_) {
        maximum = minimum + min_separation_;
      }

      InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                   H * W, C, float_mean.data(), float_variance.data(),
                   variance_epsilon_, minimum, maximum,
                   reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
      output_min->scalar<float>()() = minimum;
      output_max->scalar<float>()() = maximum;
    } else  // NOLINT(readability/braces)
#endif
    {
      VLOG(2) << "Calling unoptimized";
      float_mean = input_tensor.cast<float>().reduce(
          reduction_indices, Eigen::internal::MeanReducer<float>());

      // Unlike the NEON path, the moments here are computed on dequantized
      // values: scaling by input_scale puts the variance (and hence
      // variance_epsilon_) in the float domain.
      float_variance =
          (input_scale *
           ((input_tensor.cast<float>() -
             float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
              .square()
              .reduce(reduction_indices, Eigen::internal::MeanReducer<float>());

      Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
          input_scale *
          (input_tensor.cast<float>() -
           float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
          (float_variance + variance_epsilon_)
              .rsqrt()
              .reshape(expand_spec)
              .broadcast(broadcast_spec);

      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;

      if (!output_range_given_) {
        normed_min = instance_normed.minimum();
        normed_max = instance_normed.maximum();
      } else {
        normed_min() = given_y_min_;
        normed_max() = given_y_max_;
      }

      if (normed_max() - normed_min() < min_separation_) {
        normed_max() = normed_min() + min_separation_;
      }

      FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
      auto instance_normed_quantized =
          QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);

      output->tensor<quint8, 4>().device(
          context->template eigen_device<CPUDevice>()) =
          instance_normed_quantized;
      output_min->flat<float>()(0) = normed_min();
      output_max->flat<float>()(0) = normed_max();
    }
  }

 private:
  float variance_epsilon_;
  float min_separation_;
  bool output_range_given_;
  // Default-initialized so that Compute()'s unconditional read of these in
  // the NEON path is well-defined when output_range_given is false.
  float given_y_min_ = 0.0f;
  float given_y_max_ = 0.0f;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T"),
                        QuantizedInstanceNorm);

}  // namespace tensorflow