1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <sys/types.h>
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstddef>
20 #include <cstdint>
21 #include <cstdlib>
22 #include <cstring>
23 #include <limits>
24 #include <utility>
25 
26 #include "ruy/ruy.h"  // from @ruy
27 #include "tensorflow/lite/kernels/cpu_backend_context.h"
28 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
29 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/cppmath.h"
33 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
34 #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
35 
36 #ifdef USE_NEON
37 
38 // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
39 #if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
40 #if !defined(__ANDROID__) || __ANDROID_API__ >= 28
41 // Neither Apple nor Windows provide aligned_alloc.
42 #if !defined(__APPLE__) && !defined(_WIN32)
43 // TODO(miaowang): Re-enable std::aligned_alloc when it is avalaible in Android.
44 // #define TFLITE_USE_STD_ALIGNED_ALLOC
45 #endif
46 #endif
47 #endif
48 
49 // Note: This is the same as ABSL_HAVE_BUILTIN, but can't include the header.
50 #ifdef __has_builtin
51 #define TFLITE_HAS_BUILTIN(x) __has_builtin(x)
52 #else
53 #define TFLITE_HAS_BUILTIN(x) 0
54 #endif
55 
56 // Note: This is the same as ABSL_PREDICT_FALSE, but can't include the header.
57 #if TFLITE_HAS_BUILTIN(__builtin_expect) || \
58     (defined(__GNUC__) && !defined(__clang__))
59 #define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
60 #else
61 #define TFLITE_UNLIKELY(x) (x)
62 #endif
63 
64 namespace tflite {
65 namespace tensor_utils {
66 namespace {
67 
68 constexpr int kFloatValuesPerNeonVector = 4;
69 constexpr int kInt16ValuesPerNeonVector = 8;
70 constexpr int kInt8ValuesPerNeonVector = 16;
71 constexpr int kNeonVectorAlignment = 4;
72 
73 template <int PerNeonSize>
RoundDownVectors(int size)74 inline int RoundDownVectors(int size) {
75   return size & ~(PerNeonSize - 1);
76 }
77 
78 // Allocates, at least, size bytes of uninitialized storage whose alignment is
79 // specified by alignment. The size parameter must be an integral multiple of
80 // alignment.
81 // Caller is responsible by freeing the allocated memory by calling free on
82 // the passed freeing_buffer pointer.
aligned_alloc(size_t alignment,size_t size,void ** freeing_buffer)83 inline void* aligned_alloc(size_t alignment, size_t size,
84                            void** freeing_buffer) {
85 #ifdef TFLITE_USE_STD_ALIGNED_ALLOC
86   *freeing_buffer = std::aligned_alloc(
87       alignment, (size + alignment - 1) / alignment * alignment);
88   return *freeing_buffer;
89 #else
90   *freeing_buffer = malloc(size + alignment);
91   const size_t offset = ((uintptr_t)*freeing_buffer) % alignment;  // NOLINT
92   return offset == 0
93              ? *freeing_buffer
94              : ((char*)*freeing_buffer + (alignment - offset));  // NOLINT
95 #endif
96 }
97 
HasSdotInstruction()98 bool HasSdotInstruction() {
99   static const bool has_dotprod = DetectArmNeonDotprod();
100   return has_dotprod;
101 }
102 
AccumulateNeonLane(const float32x4_t lane)103 inline float AccumulateNeonLane(const float32x4_t lane) {
104 #ifdef __aarch64__
105   return vaddvq_f32(lane);
106 #else
107   return vgetq_lane_f32(lane, 0) + vgetq_lane_f32(lane, 1) +
108          vgetq_lane_f32(lane, 2) + vgetq_lane_f32(lane, 3);
109 #endif
110 }
111 
112 // Empirically determined breakpoints on when to use CpuBackendGemm vs.
113 // standard MatrixBatchVectorMultiplyAccumulate. Briefly, if the batch size
114 // is above 8 and the device does not have sdot, use CpuBackendGemm. Otherwise,
115 // for large batch sizes, it makes sense to use CpuBackendGemm if the matrix
116 // is not extremely rectangular.
UseCpuBackendGemm(int rows,int cols,int batch)117 bool UseCpuBackendGemm(int rows, int cols, int batch) {
118   if (!HasSdotInstruction()) {
119     return batch >= 8;
120   }
121   if (batch < 16) {
122     return false;
123   }
124   constexpr int kCpuBackendGemmThreshold = 2;
125   // Calculate "rectangularness" as a measure of how far from square the
126   // the LHS matrix is.
127   int row_rect = rows / cols;
128   int col_rect = cols / rows;
129   int rectangularness_lg2 =
130       row_rect > 0 ? FloorLog2(row_rect) : FloorLog2(col_rect);
131   int batch_lg2 = FloorLog2(batch);
132   // Large batch sizes move us above the threshold, but can be offset
133   // by significant rectangularness.
134   int batch_lg2_minus_rect_lg2 = batch_lg2 - rectangularness_lg2;
135   return batch_lg2_minus_rect_lg2 > kCpuBackendGemmThreshold;
136 }
137 
AccumulateNeonLane(const int32x4_t lane)138 inline int32_t AccumulateNeonLane(const int32x4_t lane) {
139 #ifdef __aarch64__
140   return vaddvq_s32(lane);
141 #else
142   int64x2_t pairwiseAdded = vpaddlq_s32(lane);
143   return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
144 #endif
145 }
146 
MultiplyByQuantizedMultiplier2Rows(int32x4x2_t input_val,int32 quantized_multiplier,int shift)147 inline int32x4x2_t MultiplyByQuantizedMultiplier2Rows(
148     int32x4x2_t input_val, int32 quantized_multiplier, int shift) {
149   using gemmlowp::RoundingDivideByPOT;
150   using gemmlowp::SaturatingRoundingDoublingHighMul;
151   const int left_shift = shift > 0 ? shift : 0;
152   const int right_shift = shift > 0 ? 0 : -shift;
153   int32x4x2_t result;
154   // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp
155   // is limited to NEON.
156 #ifdef GEMMLOWP_NEON
157   const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift);
158   result.val[0] =
159       RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
160                               vmulq_s32(input_val.val[0], left_shifted_one_dup),
161                               quantized_multiplier),
162                           right_shift);
163   result.val[1] =
164       RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
165                               vmulq_s32(input_val.val[1], left_shifted_one_dup),
166                               quantized_multiplier),
167                           right_shift);
168 #else
169   for (int i = 0; i < 2; ++i) {
170     int32_t vals[4];
171     vals[0] = RoundingDivideByPOT(
172         SaturatingRoundingDoublingHighMul(
173             vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift),
174             quantized_multiplier),
175         right_shift);
176     vals[1] = RoundingDivideByPOT(
177         SaturatingRoundingDoublingHighMul(
178             vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift),
179             quantized_multiplier),
180         right_shift);
181     vals[2] = RoundingDivideByPOT(
182         SaturatingRoundingDoublingHighMul(
183             vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift),
184             quantized_multiplier),
185         right_shift);
186     vals[3] = RoundingDivideByPOT(
187         SaturatingRoundingDoublingHighMul(
188             vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift),
189             quantized_multiplier),
190         right_shift);
191 
192     result.val[i] = vld1q_s32(reinterpret_cast<int32_t*>(&vals));
193   }
194 #endif
195   return result;
196 }
197 
198 }  // namespace
199 
NeonMatrixBatchVectorMultiplyAccumulate(const float * matrix,int m_rows,int m_cols,const float * vector,int n_batch,float * result)200 void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
201                                              int m_cols, const float* vector,
202                                              int n_batch, float* result) {
203   // If v_size is not divisible by the vector size, then we need to process the
204   // final few elements sequentially. postamble_start shows the start index
205   // where this should happen.
206   const int postamble_start =
207       RoundDownVectors<kFloatValuesPerNeonVector>(m_cols);
208 
209   for (int b = 0; b < n_batch; b++) {
210     float* result_in_batch = result + b * m_rows;
211     const float* vector_in_batch = vector + b * m_cols;
212     const float* matrix_row = matrix;
213 
214     // Main matrix by vector multiplication loop
215     for (int r = 0; r < m_rows; r++) {
216       float32x4_t acc_32x4 = vmovq_n_f32(0.0);
217       int c = 0;
218       for (; c < postamble_start; c += kFloatValuesPerNeonVector) {
219         // Load 4 float values from vector and matrix row.
220         float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
221         float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
222         // Multiply the vector and matrix row and add to accumulator.
223         acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
224       }
225       // Add the 4 intermediate sum values to get the final dot-prod value for
226       // this column.
227       *result_in_batch += AccumulateNeonLane(acc_32x4);
228       for (; TFLITE_UNLIKELY(c < m_cols); c++) {
229         *result_in_batch += matrix_row[c] * vector_in_batch[c];
230       }
231       matrix_row += m_cols;
232       ++result_in_batch;
233     }
234   }
235 }
236 
237 #ifdef __aarch64__
238 
239 // We interleave vector data to make the dot product logic more efficient.
240 // Suppose that vectors is:
241 //     a0 a1 a2 a3 a4 a5 ...
242 //     b0 b1 b2 b3 b4 b5 ...
243 //     c0 c1 c2 c3 c4 c5 ...
244 //     d0 d1 d2 d3 d4 d5 ...
245 //     e0 e1 e2 e3 e4 e5 ...
246 // This code interleaves them like this:
247 //     a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ...
248 //     e0 e1 e2 e3 f0 f1 f2 f3 ...
249 // Once the data is interleaved, each 16-byte read from the vectors pointer
250 // contains 4 bytes from each of 4 vectors.
ShuffleVectors(const int8_t * vectors,const int n_batch,const int m_cols,void ** shuffled_vectors_free)251 const int8_t* ShuffleVectors(const int8_t* vectors, const int n_batch,
252                              const int m_cols, void** shuffled_vectors_free) {
253   int8* shuffled_vectors = reinterpret_cast<int8*>(aligned_alloc(
254       kNeonVectorAlignment, n_batch * m_cols, shuffled_vectors_free));
255 
256   for (int i = 0; i < n_batch; i += 4) {
257     int8* shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
258     const int8* unshuffled_vec0_ptr =
259         reinterpret_cast<const int8*>(vectors) + (i * m_cols);
260     const int8* unshuffled_vec1_ptr =
261         reinterpret_cast<const int8*>(vectors) + ((i + 1) * m_cols);
262     const int8* unshuffled_vec2_ptr =
263         reinterpret_cast<const int8*>(vectors) + ((i + 2) * m_cols);
264     const int8* unshuffled_vec3_ptr =
265         reinterpret_cast<const int8*>(vectors) + ((i + 3) * m_cols);
266     const int8* const end_vec0_ptr = unshuffled_vec1_ptr;
267 
268     while (unshuffled_vec0_ptr != end_vec0_ptr) {
269       asm volatile(
270           // This code path requires that (n_cols % 16) == 0 so we can safely
271           // read in 16-byte chunks from each row.
272           "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
273           "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
274           "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
275           "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"
276 
277           "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
278           "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
279           "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
280           "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"
281 
282           : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
283             [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
284             [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
285             [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
286             [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
287           :
288           : "v0", "v1", "v2", "v3", "cc", "memory");
289     }
290   }
291 
292   return reinterpret_cast<const int8_t*>(shuffled_vectors);
293 }
294 
295 // Notes about the speed of this version vs. the baseline (from memory):
296 // - With 256K of L1, we can keep a lot of vectors in cache.
297 //   I recall a reasonable speedup just by rearranging the loop to have
298 //   row on the outside and batch on the inside.
299 // - I also recall getting a nice speedup from sdot.
300 // - I tried many times to do better than the current implementation, using
301 //   loop unrolling and instruction reordering to avoid stalls, etc.
302 //   but I was not able to do significantly better. This code is, however,
303 //   much worse than what the processor spec sheet suggests is possible.
DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)304 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
305     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
306     const int8_t* vectors, const float* scaling_factors, int n_batch,
307     float* __restrict__ result) {
308   void* shuffled_vectors_free;
309 
310   const int8_t* shuffled_vectors =
311       ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
312 
313   for (int row = 0; row < m_rows; row += 2) {
314     for (int batch = 0; batch < n_batch; batch += 4) {
315       float* result_ptr = result + (batch * m_rows) + row;
316       const int8* mat_ptr0 = matrix + (row * m_cols);
317       const int8* mat_ptr1 = matrix + ((row + 1) * m_cols);
318       const int8* mat_ptr0_end = mat_ptr1;
319       const int8* vec_ptr = shuffled_vectors + (batch * m_cols);
320       const float* scaling_factors_ptr = scaling_factors + batch;
321       const uint64_t wide_rows = m_rows * sizeof(float);
322       const int8* mat_ptr2 = matrix + ((row + 2) * m_cols);
323       const int8* mat_ptr3 = matrix + ((row + 3) * m_cols);
324 
325       asm volatile(
326           // Zero out the accumulator registers.
327           "movi v0.4s, #0\n"
328           "movi v1.4s, #0\n"
329           "movi v2.4s, #0\n"
330           "movi v3.4s, #0\n"
331 
332           "1:\n"  // batch_cols_loop
333 
334           // Read 16 more bytes from a pair of matrix rows.
335           "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
336 
337           // Prefetch two rows ahead.
338           "prfm pldl1strm, [%[mat_ptr2]]\n"
339           "prfm pldl1strm, [%[mat_ptr3]]\n"
340 
341           // Read from input vectors 4 times; 64 bytes total.
342           // Each 16-byte register contains parts of 4 vectors; see the
343           // shuffle logic above.
344 
345           // From Benoit, places to look in the future:
346           // - Move load instructions further from sdot
347           // - Switch loop use-then-reload
348           // - Do partial unrolling to use register space better
349           "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
350           ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
351           "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
352           ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
353           "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
354           ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
355           "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
356           ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
357 
358           // Update prefetch pointers.
359           "add %[mat_ptr2], %[mat_ptr2], #16\n"
360           "add %[mat_ptr3], %[mat_ptr3], #16\n"
361 
362           // Re-use those vectors for the next row as well.
363           "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
364           ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
365           ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
366           ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
367           ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
368 
369           // If we're not done with these rows, continue.
370           "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
371           "bne 1b\n"  // batch_cols_loop
372 
373           // Done with the rows, sum the results.
374           "add v0.4s, v0.4s, v1.4s\n"
375           "add v2.4s, v2.4s, v3.4s\n"
376 
377           // Convert the per-vector sums to floating point.
378           "scvtf v0.4s, v0.4s\n"
379           "scvtf v1.4s, v2.4s\n"
380 
381           // Fetch scale factors.
382           "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
383 
384           // Multiply scale factors times sums.
385           "fmul v0.4s, v4.4s, v0.4s\n"
386           "fmul v1.4s, v4.4s, v1.4s\n"
387 
388           // Load previous result values.
389           // The result position is:
390           //   result[batch * m_rows + row]
391           // Here that is factored into:
392           //   result_ptr = result + row
393           //   *result_ptr = res[0]
394           //   (uint8*)result_ptr += (m_rows * sizeof(float))
395           //   *result_ptr = res[1]
396           //   ...
397           // Since we're reading two rows at a time, though, we read both
398           //   result[batch * m_rows + row]
399           // and
400           //   result[batch * m_rows + row + 1]
401           "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
402           "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
403           "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
404           "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
405 
406           // Go back to the starting position (subtract wide_rows * 4).
407           "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
408 
409           // Add previous result values.
410           "fadd v9.4s, v9.4s, v0.4s\n"
411           "fadd v10.4s, v10.4s, v1.4s\n"
412 
413           // Store results.
414           "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
415           "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
416           "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
417           "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
418           : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1),
419             [vec_ptr] "+r"(vec_ptr), [result_ptr] "+r"(result_ptr),
420             [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
421           : [mat_ptr0_end] "r"(mat_ptr0_end),
422             [scaling_factors_ptr] "r"(scaling_factors_ptr),
423             [wide_rows] "r"(wide_rows)
424           : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
425             "v10", "v11", "v12", "v13", "cc", "memory");
426     }
427   }
428 
429   free(shuffled_vectors_free);
430 }
431 
DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset,int32_t * row_sums)432 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
433     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
434     const int8_t* vectors, const float* scaling_factors, int n_batch,
435     float* __restrict__ result, const float* per_channel_scale,
436     const int32_t* input_offset, int32_t* row_sums) {
437   void* shuffled_vectors_free;
438   const int8_t* shuffled_vectors =
439       ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
440 
441   for (int row = 0; row < m_rows; row += 2) {
442     const float* channel_scales_ptr = per_channel_scale + row;
443     int32_t* row_sums_ptr = row_sums ? row_sums + row : nullptr;
444     for (int batch = 0; batch < n_batch; batch += 4) {
445       float* result_ptr = result + (batch * m_rows) + row;
446       const int8* mat_ptr0 = matrix + (row * m_cols);
447       const int8* mat_ptr1 = matrix + ((row + 1) * m_cols);
448       const int8* mat_ptr0_end = mat_ptr1;
449       const int8* vec_ptr = shuffled_vectors + (batch * m_cols);
450       const float* scaling_factors_ptr = scaling_factors + batch;
451       const uint64_t wide_rows = m_rows * sizeof(float);
452       const int32_t* batch_offsets_ptr = input_offset + batch;
453       const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
454       const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;
455       asm volatile(
456           "movi v0.4s, #0\n"
457           "movi v1.4s, #0\n"
458           "movi v2.4s, #0\n"
459           "movi v3.4s, #0\n"
460           // Load zero points.
461           "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
462           "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
463           // Zero out zero point accumulators.
464           "movi v14.4s, #0\n"
465           "movi v15.4s, #0\n"
466 
467           // Load per channel scales if not null.
468           "cmp %w[is_channel_scale_nullptr], #0\n"
469           "bne 1f\n"
470           "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
471           "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
472           "fmul v16.4s, v16.4s, v4.4s\n"
473           "fmul v17.4s, v17.4s, v4.4s\n"
474           "b 2f\n"
475           "1:\n"
476           "mov v16.16b, v4.16b\n"
477           "mov v17.16b, v4.16b\n"
478           "2:\n"
479           "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
480           "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
481           ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
482           "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
483           ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
484           "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
485           ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
486           "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
487           ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
488           "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
489           ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
490           ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
491           ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
492           ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
493           "cmp %w[is_row_sums_nullptr], #1\n"
494           "bne 3f\n"
495           // Accumulate row_sums for zero point calculations.
496           "saddlp v12.8h, v12.16b\n"
497           "saddlp v13.8h, v13.16b\n"
498           "sadalp v14.4s, v12.8h\n"
499           "sadalp v15.4s, v13.8h\n"
500           "3:\n"
501           "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
502           "bne 2b\n"
503           "add v0.4s, v0.4s, v1.4s\n"
504           "add v2.4s, v2.4s, v3.4s\n"
505 
506           "cmp %w[is_row_sums_nullptr], #1\n"
507           "bne 4f\n"
508           // Calculate zero point offsets.
509           "addv s14, v14.4s\n"
510           "addv s15, v15.4s\n"
511           "dup v14.4s, v14.s[0]\n"
512           "dup v15.4s, v15.s[0]\n"
513           "b 5f\n"
514           "4:\n"
515           "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
516           "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
517           "5:\n"
518 
519           "mul v14.4s, v14.4s, v7.4s\n"
520           "mul v15.4s, v15.4s, v7.4s\n"
521           "sub v0.4s, v0.4s, v14.4s\n"
522           "sub v2.4s, v2.4s, v15.4s\n"
523 
524           "scvtf v0.4s, v0.4s\n"
525           "scvtf v1.4s, v2.4s\n"
526 
527           // Multiply scale.
528           "fmul v0.4s, v16.4s, v0.4s\n"
529           "fmul v1.4s, v17.4s, v1.4s\n"
530 
531           "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
532           "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
533           "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
534           "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
535           "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
536           "fadd v9.4s, v9.4s, v0.4s\n"
537           "fadd v10.4s, v10.4s, v1.4s\n"
538           "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
539           "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
540           "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
541           "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
542           : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1),
543             [vec_ptr] "+r"(vec_ptr), [result_ptr] "+r"(result_ptr),
544             [row_sums_ptr] "+r"(row_sums_ptr)
545           : [mat_ptr0_end] "r"(mat_ptr0_end),
546             [scaling_factors_ptr] "r"(scaling_factors_ptr),
547             [wide_rows] "r"(wide_rows),
548             [channel_scales_ptr] "r"(channel_scales_ptr),
549             [batch_offsets_ptr] "r"(batch_offsets_ptr),
550             [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
551             [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
552           : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
553             "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "w0", "w1",
554             "cc", "memory");
555     }
556   }
557 
558   free(shuffled_vectors_free);
559 }
560 
DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset)561 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
562     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
563     const int8_t* vectors, const float* scaling_factors, int n_batch,
564     float* __restrict__ result, const float* per_channel_scale,
565     const int32_t* input_offset) {
566   DotprodMatrixBatchFourVectorMultiplyAccumulate(
567       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
568       per_channel_scale, input_offset, nullptr);
569 }
570 
571 // The DotprodMatrixBatchFourVectorMultiplyAccumulate kernel processes 4
572 // vectors in the same time as the baseline processes 1 vector. However, it
573 // requires 4 vectors of input.
574 //
575 // To take advantage of this speed difference, we add some zero-valued
576 // vectors to the batch so that n_batch is a multiple of 4. Then we execute
577 // DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate on that padded batch,
578 // then extract just the results we want at the end (ignoring the extra padding
579 // outputs).
580 //
581 // The relative cost of the padding is large when the matrix is smaller than
582 // 128x128, so we don't use this code path on small matrices. On larger
583 // matrices, the computation cost dwarfs the padding cost, making this code
584 // viable.
585 //
586 // If we ignore the cost of padding, this kernel is:
587 //    1x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 1
588 //    2x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 2
589 //    3x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 3
590 //    ...
591 //
592 // We don't use this kernel when n_batch = 1 because the baseline kernel
593 // is fine for that case.
DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset,int32_t * row_sums)594 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
595     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
596     const int8_t* vectors, const float* scaling_factors, int n_batch,
597     float* __restrict__ result, const float* per_channel_scale,
598     const int32_t* input_offset, int32_t* row_sums) {
599   // Round to the nearest multiple of 4.
600   int batch_round_up = n_batch;
601   if (n_batch % 4 != 0) {
602     batch_round_up += (4 - n_batch % 4);
603   }
604   TFLITE_CHECK_LE(n_batch, batch_round_up);
605 
606   void* padded_vectors_free;
607   const int padded_vectors_size = batch_round_up * m_cols;
608   int8_t* padded_vectors = reinterpret_cast<int8_t*>(aligned_alloc(
609       kNeonVectorAlignment, padded_vectors_size, &padded_vectors_free));
610   memset(padded_vectors, 0, padded_vectors_size);
611 
612   void* padded_result_free;
613   const int result_size = n_batch * m_rows * sizeof(float);
614   const int padded_result_size = batch_round_up * m_rows * sizeof(float);
615   float* padded_result = reinterpret_cast<float*>(aligned_alloc(
616       kNeonVectorAlignment, padded_result_size, &padded_result_free));
617   memcpy(padded_result, result, result_size);
618   memset(reinterpret_cast<char*>(padded_result) + result_size, 0,
619          padded_result_size - result_size);
620 
621   // Copy the input into the padded data structure.
622   TFLITE_CHECK_LE(n_batch * m_cols, padded_vectors_size);
623   memcpy(padded_vectors, vectors, n_batch * m_cols);
624 
625   void* padded_scaling_factors_free;
626   const int padded_scaling_factors_size = batch_round_up * sizeof(float);
627   float* padded_scaling_factors = reinterpret_cast<float*>(
628       aligned_alloc(kNeonVectorAlignment, padded_scaling_factors_size,
629                     &padded_scaling_factors_free));
630   TFLITE_CHECK_LE(n_batch * sizeof(float), padded_scaling_factors_size);
631   TFLITE_CHECK_LE(batch_round_up * sizeof(float), padded_scaling_factors_size);
632   memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
633   memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));
634 
635   if (input_offset != nullptr) {
636     void* padded_input_offset_free;
637     const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
638     int32_t* padded_input_offset = reinterpret_cast<int32_t*>(
639         aligned_alloc(kNeonVectorAlignment, padded_input_offset_size,
640                       &padded_input_offset_free));
641     TFLITE_CHECK_LE(n_batch * sizeof(int32_t), padded_input_offset_size);
642     TFLITE_CHECK_LE(batch_round_up * sizeof(int32_t), padded_input_offset_size);
643     memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
644     memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));
645 
646     // Call the main kernel.
647     DotprodMatrixBatchFourVectorMultiplyAccumulate(
648         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors,
649         batch_round_up, padded_result, per_channel_scale, padded_input_offset,
650         row_sums);
651 
652     free(padded_input_offset_free);
653   } else {
654     // Call the main kernel.
655     DotprodMatrixBatchFourVectorMultiplyAccumulate(
656         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors,
657         batch_round_up, padded_result);
658   }
659   memcpy(result, padded_result, result_size);
660 
661   free(padded_result_free);
662   free(padded_vectors_free);
663   free(padded_scaling_factors_free);
664 }
665 
DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)666 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
667     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
668     const int8_t* vectors, const float* scaling_factors, int n_batch,
669     float* __restrict__ result) {
670   DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
671       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
672       /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
673       /*row_sums=*/nullptr);
674 }
675 
DotprodSparseMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const uint8_t * ledger,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)676 static void DotprodSparseMatrixBatchVectorMultiplyAccumulate(
677     const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
678     const int m_cols, const int8_t* __restrict__ vectors,
679     const float* scaling_factors, int n_batch, float* __restrict__ result) {
680   const uint8_t* ledger_ptr = ledger;
681   const int8* mat_ptr = matrix;
682 
683   for (int row = 0; row < m_rows; row++) {
684     int num_nonzero_chunks = *ledger_ptr;
685     ledger_ptr++;
686     const uint8* ledger_start = ledger_ptr;
687     const uint8* ledger_end = ledger_ptr + num_nonzero_chunks;
688     const int8* mat_start = mat_ptr;
689 
690     for (int batch = 0; batch < n_batch; batch++) {
691       const int8* vec_ptr = vectors + (batch * m_cols);
692       int64_t row_sum = 0;
693 
694       mat_ptr = mat_start;
695       ledger_ptr = ledger_start;
696 
697       if (ledger_ptr != ledger_end) {
698         asm volatile(
699             "movi v0.4s, #0\n"
700             "movi v1.4s, #0\n"
701             "movi v8.4s, #0\n"
702             "mov x7, 0\n"
703 
704             "1:\n"  // chunks_loop
705 
706             // Single matrix chunk, 16 bytes
707             "ld1 {v8.16b}, [%[mat_ptr]], #16\n"
708 
709             // Read the next ledger index and increment.
710             "ldrb w7, [%[ledger_ptr]], #1\n"
711 
712             // Read 16 bytes of vector data from (vec_ptr + (ledger_index * 16))
713             "add x8, %[vec_ptr], x7, lsl #4\n"
714             "ld1 {v9.16b}, [x8]\n"
715 
716             // Dot product of matrix row and vector.
717             ".word 0x4e889520  // sdot v0.4s, v9.16b, v8.16b\n"
718 
719             "cmp %[ledger_ptr], %[ledger_end]\n"
720             "blt 1b\n"  // chunks_loop
721 
722             // Sum the 4 vector components into a 32-bit value.
723             "addv s1, v0.4s\n"
724             // row_sum is 64-bit, so we copy 64 bits of v1 into it.
725             // We have to be careful to cast this value to 32 bits in order
726             // to interpret the sign bit properly.
727             "mov %[row_sum], v1.d[0]\n"
728             : [row_sum] "=r"(row_sum), [ledger_ptr] "+r"(ledger_ptr),
729               [mat_ptr] "+r"(mat_ptr), [vec_ptr] "+r"(vec_ptr)
730             : [ledger_end] "r"(ledger_end)
731             : "x0", "x1", "x7", "x8", "v0", "v1", "v8", "v9", "cc", "memory");
732       }
733       result[batch * m_rows + row] +=
734           static_cast<int32>(row_sum) * scaling_factors[batch];
735     }
736   }
737 }
738 
739 #endif  // __aarch64__
740 
NeonMatrixBatchVectorMultiplyImpl(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch)741 void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
742                                        const int8_t* input_to_gate_weights,
743                                        int32_t n_batch, int32_t n_input,
744                                        int32_t n_output, int32_t output_zp,
745                                        int32_t* scratch) {
746   // Assuming *matrix is kNeonVectorAlignment-byte aligned, every row of the
747   // matrix is also kNeonVectorAlignment-byte aligned as long as cols is a
748   // multiple of kNeonVectorAlignment. The assumption is currently satisfied by
749   // TFLite's 16-byte memory alignment scheme.
750   //
751   // Otherwise, we allocate an aligned memory block and set
752   // a flag to later copy rows from matrix to the block
753   // for aligned multiplication.
754   bool unaligned = false;
755   int8_t* aligned_row = nullptr;
756   void* aligned_row_free = nullptr;
757   if ((n_input & (kNeonVectorAlignment - 1)) != 0) {
758     unaligned = true;
759     aligned_row =
760         (int8_t*)aligned_alloc(kNeonVectorAlignment, n_input,  // NOLINT
761                                &aligned_row_free);
762   }
763   void* aligned_vec_free = nullptr;
764   int8_t* aligned_vec =
765       (int8_t*)aligned_alloc(kNeonVectorAlignment, n_input,  // NOLINT
766                              &aligned_vec_free);
767 
768   // If m_cols is not at least kInt8ValuesPerNeonVector, we cannot use the main
769   // vectorized loop, and we need to process sequentially. postamble_half_start
770   // shows the start index where this should happen. Between postamble_start and
771   // postamble_half_start we can still process kInt8ValuesPerNeonVector/2 in a
772   // vectorized form.
773   const int postamble_half_start =
774       RoundDownVectors<kInt8ValuesPerNeonVector>(n_input);
775   const int postamble_start =
776       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(n_input);
777 
778   for (int batch = 0; batch < n_batch; ++batch) {
779     // Copy the vector data to an aligned vector.
780     memcpy(aligned_vec, input + batch * n_input, sizeof(int8_t) * n_input);
781     // Compute dot-product for every column.
782     for (int row = 0; row < n_output; ++row) {
783       // Get the address of the first element of the row.
784       int8_t* row_ptr =
785           (int8_t*)input_to_gate_weights + row * n_input;  // NOLINT
786       if (unaligned) {
787         memcpy(aligned_row, row_ptr, sizeof(int8_t) * n_input);
788         row_ptr = aligned_row;
789       }
790 
791       // Initialize the dot product sum for the row to 0.
792       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
793 
794       // For every block of 16 8-bit elements.
795       int col = 0;
796       for (; col < postamble_half_start; col += kInt8ValuesPerNeonVector) {
797         // Load 16 8-bit values from the row and vector, each, to operate on.
798         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
799         // performance may suffer significantly.
800         TFLITE_DCHECK_EQ(  // NOLINT
801             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
802         const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
803         const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
804         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
805         // registers).
806         int16x8_t prod_16x8 =
807             vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
808         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
809         // registers), and accumulate with the result of the low bits product.
810         // The assumption here is that overflow will not happen as we quantize
811         // our values to be in the range [-127, 127]. As such the sum of the 2
812         // products is always strictly smaller than 15-bits (32767 in absolute
813         // value).
814         prod_16x8 =
815             vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
816 
817         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
818       }  // for col
819 
820       // Half iteration dealing only 8 elements
821       if (TFLITE_UNLIKELY(col < postamble_start)) {
822         // Load 8 8-bit values from the row and column each to operate on.
823         // Here the assumption is that each buffer is 4-bytes aligned.
824         // Otherwise, performance may suffer significantly.
825         TFLITE_DCHECK_EQ(  // NOLINT
826             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
827         const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
828         const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
829         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
830         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
831         col += (kInt8ValuesPerNeonVector >> 1);
832       }
833       // Add the 4 intermediate sum values to get the final dot-prod value for
834       // this row.
835       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
836       // Postamble loop.
837       for (; TFLITE_UNLIKELY(col < n_input); ++col) {
838         dotprod += row_ptr[col] * aligned_vec[col];
839       }  // for col
840 
841       dotprod += bias[row];
842       scratch[batch * n_output + row] = dotprod;
843     }  // for row
844   }    // for batch
845 
846   if (unaligned) {
847     free(aligned_row_free);
848   }
849   free(aligned_vec_free);
850 }
851 
NeonMatrixBatchVectorAccumulateImpl(int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_output,int32_t output_zp,int32_t * scratch,int16_t * output)852 inline void NeonMatrixBatchVectorAccumulateImpl(
853     int32_t multiplier, int32_t shift, int32_t n_batch, int32_t n_output,
854     int32_t output_zp, int32_t* scratch, int16_t* output) {
855   int i = 0;
856   const int total_size = n_batch * n_output;
857 
858   const int32_t output_min = std::numeric_limits<int16_t>::min();
859   const int32_t output_max = std::numeric_limits<int16_t>::max();
860 
861   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
862   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
863   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
864 
865   using gemmlowp::RoundingDivideByPOT;
866   using gemmlowp::SaturatingRoundingDoublingHighMul;
867 
868   for (; i <= total_size - 8; i += 8) {
869     int32x4x2_t scratch_val;
870     scratch_val.val[0] = vld1q_s32(scratch + i);
871     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
872     const int16x8_t output_val = vld1q_s16(output + i);
873     const int32x4_t first_half = vmovl_s16(vget_low_s16(output_val));
874     const int32x4_t second_half = vmovl_s16(vget_high_s16(output_val));
875 
876     int32x4x2_t temp_val =
877         MultiplyByQuantizedMultiplier2Rows(scratch_val, multiplier, shift);
878 
879     temp_val.val[0] =
880         vaddq_s32(vaddq_s32(temp_val.val[0], first_half), output_zp_dup);
881     temp_val.val[1] =
882         vaddq_s32(vaddq_s32(temp_val.val[1], second_half), output_zp_dup);
883     temp_val.val[0] =
884         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
885     temp_val.val[1] =
886         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
887     const int16x8_t result =
888         vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
889     vst1q_s16(output + i, result);
890   }
891   for (; TFLITE_UNLIKELY(i < total_size); ++i) {
892     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
893     temp += output_zp;
894     temp += output[i];
895     if (temp > output_max) {
896       temp = output_max;
897     }
898     if (temp < output_min) {
899       temp = output_min;
900     }
901     output[i] = static_cast<int16_t>(temp);
902   }
903 }
904 
NeonMatrixBatchVectorAccumulateImpl(int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_output,int32_t output_zp,int32_t * scratch,int8_t * output)905 inline void NeonMatrixBatchVectorAccumulateImpl(
906     int32_t multiplier, int32_t shift, int32_t n_batch, int32_t n_output,
907     int32_t output_zp, int32_t* scratch, int8_t* output) {
908   int i = 0;
909   const int total_size = n_batch * n_output;
910 
911   const int32_t output_min = std::numeric_limits<int8_t>::min();
912   const int32_t output_max = std::numeric_limits<int8_t>::max();
913 
914   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
915   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
916   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
917 
918   using gemmlowp::RoundingDivideByPOT;
919   using gemmlowp::SaturatingRoundingDoublingHighMul;
920 
921   for (; i <= total_size - 16; i += 16) {
922     int32x4x4_t scratch_val;
923     scratch_val.val[0] = vld1q_s32(scratch + i);
924     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
925     scratch_val.val[2] = vld1q_s32(scratch + i + 8);
926     scratch_val.val[3] = vld1q_s32(scratch + i + 12);
927 
928     const int8x16_t output_val = vld1q_s8(output + i);
929     const int16x8_t first_half = vmovl_s8(vget_low_s8(output_val));
930     const int16x8_t second_half = vmovl_s8(vget_high_s8(output_val));
931     const int32x4_t output_val_1 = vmovl_s16(vget_low_s16(first_half));
932     const int32x4_t output_val_2 = vmovl_s16(vget_high_s16(first_half));
933     const int32x4_t output_val_3 = vmovl_s16(vget_low_s16(second_half));
934     const int32x4_t output_val_4 = vmovl_s16(vget_high_s16(second_half));
935 
936     int32x4x4_t temp_val =
937         MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
938 
939     temp_val.val[0] =
940         vaddq_s32(vaddq_s32(temp_val.val[0], output_val_1), output_zp_dup);
941     temp_val.val[1] =
942         vaddq_s32(vaddq_s32(temp_val.val[1], output_val_2), output_zp_dup);
943     temp_val.val[2] =
944         vaddq_s32(vaddq_s32(temp_val.val[2], output_val_3), output_zp_dup);
945     temp_val.val[3] =
946         vaddq_s32(vaddq_s32(temp_val.val[3], output_val_4), output_zp_dup);
947 
948     temp_val.val[0] =
949         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
950     temp_val.val[1] =
951         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
952     temp_val.val[2] =
953         vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
954     temp_val.val[3] =
955         vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
956 
957     const int16x8_t result_1 =
958         vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
959     const int16x8_t result_2 =
960         vcombine_s16(vqmovn_s32(temp_val.val[2]), vqmovn_s32(temp_val.val[3]));
961     const int8x16_t result =
962         vcombine_s8(vqmovn_s16(result_1), vqmovn_s16(result_2));
963     vst1q_s8(output + i, result);
964   }
965   for (; TFLITE_UNLIKELY(i < total_size); ++i) {
966     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
967     temp += output_zp;
968     temp += output[i];
969     if (temp > output_max) {
970       temp = output_max;
971     }
972     if (temp < output_min) {
973       temp = output_min;
974     }
975     output[i] = static_cast<int8_t>(temp);
976   }
977 }
978 
NeonCpuBackendGemm(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,CpuBackendContext * context)979 void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
980                         const int8_t* input_to_gate_weights, int32_t n_batch,
981                         int32_t n_input, int32_t n_output, int32_t output_zp,
982                         int32_t* scratch, CpuBackendContext* context) {
983   using ::tflite::cpu_backend_gemm::Gemm;
984   using ::tflite::cpu_backend_gemm::GemmParams;
985   using ::tflite::cpu_backend_gemm::MatrixParams;
986 
987   MatrixParams<int8_t> lhs_params;
988   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
989   lhs_params.rows = n_output;
990   lhs_params.cols = n_input;
991   lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
992 
993   MatrixParams<int8_t> rhs_params;
994   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
995   rhs_params.rows = n_input;
996   rhs_params.cols = n_batch;
997 
998   MatrixParams<int32_t> dst_params;
999   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1000   dst_params.rows = n_output;
1001   dst_params.cols = n_batch;
1002 
1003   GemmParams<int32, int32> gemm_params;
1004   if (bias) {
1005     gemm_params.bias = bias;
1006   }
1007   cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
1008                          dst_params, scratch, gemm_params, context);
1009 }
1010 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,int16_t * output,CpuBackendContext * context)1011 void NeonMatrixBatchVectorMultiplyAccumulate(
1012     const int8_t* input, const int32_t* bias,
1013     const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
1014     int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
1015     int32_t* scratch, int16_t* output, CpuBackendContext* context) {
1016 #ifdef TFLITE_WITH_RUY_GEMV
1017   NeonCpuBackendGemm(input, bias, input_to_gate_weights, n_batch, n_input,
1018                      n_output, output_zp, scratch, context);
1019 #else
1020   NeonMatrixBatchVectorMultiplyImpl(input, bias, input_to_gate_weights, n_batch,
1021                                     n_input, n_output, output_zp, scratch);
1022 #endif
1023   NeonMatrixBatchVectorAccumulateImpl(multiplier, shift, n_batch, n_output,
1024                                       output_zp, scratch, output);
1025 }
1026 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,int8_t * output,CpuBackendContext * context)1027 void NeonMatrixBatchVectorMultiplyAccumulate(
1028     const int8_t* input, const int32_t* bias,
1029     const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
1030     int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
1031     int32_t* scratch, int8_t* output, CpuBackendContext* context) {
1032 #ifdef TFLITE_WITH_RUY_GEMV
1033   NeonCpuBackendGemm(input, bias, input_to_gate_weights, n_batch, n_input,
1034                      n_output, output_zp, scratch, context);
1035 #else
1036   NeonMatrixBatchVectorMultiplyImpl(input, bias, input_to_gate_weights, n_batch,
1037                                     n_input, n_output, output_zp, scratch);
1038 #endif
1039   NeonMatrixBatchVectorAccumulateImpl(multiplier, shift, n_batch, n_output,
1040                                       output_zp, scratch, output);
1041 }
1042 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)1043 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
1044                                              const int m_rows, const int m_cols,
1045                                              const int8_t* __restrict__ vectors,
1046                                              const float* scaling_factors,
1047                                              int n_batch,
1048                                              float* __restrict__ result) {
1049 #ifdef __aarch64__
1050   if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 &&
1051       m_rows >= n_batch) {
1052     if (n_batch % 4 == 0) {
1053       // Benchmarks suggest that it's always better to use the batch code
1054       // when we can, even on small matrices.
1055       DotprodMatrixBatchFourVectorMultiplyAccumulate(
1056           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
1057       return;
1058     } else if (n_batch >= 2 && m_rows * m_cols >= 128 * 128) {
1059       DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
1060           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
1061       return;
1062     }
1063   }
1064 #endif  // __aarch64__
1065 
1066   // Assuming *matrix is kNeonVectorAlignment-byte aligned, every row of the
1067   // matrix is also kNeonVectorAlignment-byte aligned as long as cols is a
1068   // multiple of kNeonVectorAlignment. The assumption is currently satisfied by
1069   // TFLite's 16-byte memory alignment scheme.
1070   //
1071   // Otherwise, we allocate an aligned memory block and set
1072   // a flag to later copy rows from matrix to the block
1073   // for aligned multiplication.
1074   bool unaligned = false;
1075   int8_t* aligned_row = nullptr;
1076   void* aligned_row_free = nullptr;
1077   if ((m_cols & (kNeonVectorAlignment - 1)) != 0) {
1078     unaligned = true;
1079     aligned_row =
1080         (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1081                                &aligned_row_free);
1082   }
1083   void* aligned_vec_free = nullptr;
1084   int8_t* aligned_vec =
1085       (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1086                              &aligned_vec_free);
1087 
1088   // If m_cols is not at least kInt8ValuesPerNeonVector, we cannot use the main
1089   // vectorized loop, and we need to process sequentially. postamble_half_start
1090   // shows the start index where this should happen. Between postamble_start and
1091   // postamble_half_start we can still process kInt8ValuesPerNeonVector/2 in a
1092   // vectorized form.
1093   const int postamble_half_start =
1094       RoundDownVectors<kInt8ValuesPerNeonVector>(m_cols);
1095   const int postamble_start =
1096       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(m_cols);
1097 
1098   for (int batch = 0; batch < n_batch; ++batch) {
1099     const float batch_scaling_factor = scaling_factors[batch];
1100     // Copy the vector data to an aligned vector.
1101     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
1102     // Compute dot-product for every column.
1103     for (int row = 0; row < m_rows; ++row) {
1104       // Get the address of the first element of the row.
1105       int8_t* row_ptr = (int8_t*)matrix + row * m_cols;  // NOLINT
1106       if (unaligned) {
1107         memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
1108         row_ptr = aligned_row;
1109       }
1110 
1111       // Initialize the dot product sum for the row to 0.
1112       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
1113 
1114       // Prefetch the row to cache.
1115       __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
1116                          3 /* temporal locality */);
1117 
1118       // For every block of 16 8-bit elements.
1119       int col = 0;
1120       for (; col < postamble_half_start; col += kInt8ValuesPerNeonVector) {
1121         // Load 16 8-bit values from the row and vector, each, to operate on.
1122         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
1123         // performance may suffer significantly.
1124         TFLITE_DCHECK_EQ(  // NOLINT
1125             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1126         const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
1127         const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
1128         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
1129         // registers).
1130         int16x8_t prod_16x8 =
1131             vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
1132         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
1133         // registers), and accumulate with the result of the low bits product.
1134         // The assumption here is that overflow will not happen as we quantize
1135         // our values to be in the range [-127, 127]. As such the sum of the 2
1136         // products is always strictly smaller than 15-bits (32767 in absolute
1137         // value).
1138         prod_16x8 =
1139             vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
1140 
1141         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1142       }  // for col
1143 
1144       // Half iteration dealing only 8 elements
1145       if (TFLITE_UNLIKELY(col < postamble_start)) {
1146         // Load 8 8-bit values from the row and column each to operate on.
1147         // Here the assumption is that each buffer is 4-bytes aligned.
1148         // Otherwise, performance may suffer significantly.
1149         TFLITE_DCHECK_EQ(  // NOLINT
1150             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1151         const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
1152         const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
1153         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
1154         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1155         col += (kInt8ValuesPerNeonVector >> 1);
1156       }
1157       // Add the 4 intermediate sum values to get the final dot-prod value for
1158       // this row.
1159       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
1160       // Postamble loop.
1161       for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
1162         dotprod += row_ptr[col] * aligned_vec[col];
1163       }  // for col
1164 
1165       *result += dotprod * batch_scaling_factor;
1166       ++result;
1167     }  // for row
1168   }    // for batch
1169 
1170   if (unaligned) {
1171     free(aligned_row_free);
1172   }
1173   free(aligned_vec_free);
1174 }
1175 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,int32_t * scratch,float * __restrict__ result,CpuBackendContext * context)1176 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
1177                                              const int m_rows, const int m_cols,
1178                                              const int8_t* __restrict__ vectors,
1179                                              const float* scaling_factors,
1180                                              int n_batch, int32_t* scratch,
1181                                              float* __restrict__ result,
1182                                              CpuBackendContext* context) {
1183   if (m_rows % 4 == 0) {
1184     const int32_t* bias = static_cast<const int32_t*>(nullptr);
1185     NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
1186                        /*output_zp =*/0, scratch, context);
1187 
1188     // Multiply by float scaling factors and write to result
1189     const int total_size = n_batch * m_rows;
1190     int i = 0;
1191     for (; i <= total_size - 8; i += 8, result += 8) {
1192       const float batch_scaling_factor0 = scaling_factors[i / m_rows];
1193       const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
1194       const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
1195       const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
1196       const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
1197       const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
1198       const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
1199       const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
1200       const float32x4_t result0 =
1201           vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
1202       const float32x4_t result1 =
1203           vmlaq_f32(vld1q_f32(result + 4), float_val1, scaling_factor1);
1204       vst1q_f32(result, result0);
1205       vst1q_f32(result + 4, result1);
1206     }
1207     scratch += i;
1208     for (; TFLITE_UNLIKELY(i < total_size); i++) {
1209       const float batch_scaling_factor = scaling_factors[i / m_rows];
1210       int32_t x = *(scratch++);
1211       *result += x * batch_scaling_factor;
1212       ++result;
1213     }
1214     return;
1215   }
1216   NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
1217                                           scaling_factors, n_batch, result);
1218 }
1219 
NeonMatrixScalarMultiplyAccumulate(const int8_t * matrix,int32_t scalar,int32_t n_row,int32_t n_col,int32_t * output)1220 void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
1221                                         int32_t n_row, int32_t n_col,
1222                                         int32_t* output) {
1223   // Processing multiple rows at the same time actually makes it slower. :(
1224   for (int i = 0; i < n_row; ++i) {
1225     int32x4_t row_sum = vdupq_n_s32(0);
1226     int j = 0;
1227     const int8_t* row_ptr = matrix + i * n_col;
1228     for (; j <= n_col - kInt8ValuesPerNeonVector;
1229          j += kInt8ValuesPerNeonVector) {
1230       const int8x16_t input_value = vld1q_s8(row_ptr + j);
1231       int16x8_t temp = vmovl_s8(vget_low_s8(input_value));
1232       temp = vaddw_s8(temp, vget_high_s8(input_value));
1233       row_sum = vpadalq_s16(row_sum, temp);
1234     }
1235     int32_t sum = AccumulateNeonLane(row_sum);
1236     for (; TFLITE_UNLIKELY(j < n_col); ++j) {
1237       sum += *(row_ptr + j);
1238     }
1239     output[i] += sum * scalar;
1240   }
1241 }
1242 
NeonMatrixBatchVectorMultiplyAccumulateImpl(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset,int32_t * row_sums)1243 void NeonMatrixBatchVectorMultiplyAccumulateImpl(
1244     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
1245     const int8_t* __restrict__ vectors, const float* scaling_factors,
1246     int n_batch, float* __restrict__ result, const float* per_channel_scale,
1247     const int32_t* input_offset, int32_t* row_sums) {
1248 #ifdef __aarch64__
1249   if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 &&
1250       m_rows >= n_batch) {
1251     if (n_batch % 4 == 0) {
1252       DotprodMatrixBatchFourVectorMultiplyAccumulate(
1253           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
1254           per_channel_scale, input_offset, row_sums);
1255       return;
1256     } else if (n_batch >= 2 && m_rows * m_cols >= 128 * 128) {
1257       DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
1258           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
1259           per_channel_scale, input_offset, row_sums);
1260       return;
1261     }
1262   }
1263 #endif  // __aarch64__
1264 
1265   bool unaligned = false;
1266   int8_t* aligned_row = nullptr;
1267   void* aligned_row_free = nullptr;
1268   if ((m_cols & (kNeonVectorAlignment - 1)) != 0) {
1269     unaligned = true;
1270     aligned_row =
1271         (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1272                                &aligned_row_free);
1273   }
1274   void* aligned_vec_free = nullptr;
1275   int8_t* aligned_vec =
1276       (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1277                              &aligned_vec_free);
1278 
1279   const int postamble_half_start =
1280       RoundDownVectors<kInt8ValuesPerNeonVector>(m_cols);
1281   const int postamble_start =
1282       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(m_cols);
1283 
1284   int32_t* row_sums_ptr = row_sums;
1285   if (row_sums == nullptr) {
1286     row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
1287     NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
1288   }
1289 
1290   for (int batch = 0; batch < n_batch; ++batch) {
1291     const float batch_scaling_factor = scaling_factors[batch];
1292     const int batch_input_offset = input_offset[batch];
1293     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
1294     for (int row = 0; row < m_rows; ++row) {
1295       int8_t* row_ptr = (int8_t*)matrix + row * m_cols;  // NOLINT
1296       if (unaligned) {
1297         memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
1298         row_ptr = aligned_row;
1299       }
1300       float scale = batch_scaling_factor;
1301       if (per_channel_scale) {
1302         scale *= per_channel_scale[row];
1303       }
1304       // Initialize the dot product sum for the row to 0.
1305       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
1306 
1307       // Prefetch the row to cache.
1308       __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
1309                          3 /* temporal locality */);
1310 
1311       // For every block of 16 8-bit elements.
1312       int col = 0;
1313       for (; col < postamble_half_start; col += kInt8ValuesPerNeonVector) {
1314         // Load 16 8-bit values from the row and vector, each, to operate on.
1315         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
1316         // performance may suffer significantly.
1317         TFLITE_DCHECK_EQ(  // NOLINT
1318             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1319         const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
1320         const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
1321         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
1322         // registers).
1323         int16x8_t prod_16x8 =
1324             vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
1325         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
1326         // registers), and accumulate with the result of the low bits product.
1327         // The assumption here is that overflow will not happen as we quantize
1328         // our values to be in the range [-127, 127]. As such the sum of the 2
1329         // products is always strictly smaller than 15-bits (32767 in absolute
1330         // value).
1331         prod_16x8 =
1332             vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
1333         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1334       }  // for col
1335 
1336       // Half iteration dealing only 8 elements
1337       if (TFLITE_UNLIKELY(col < postamble_start)) {
1338         // Load 8 8-bit values from the row and column each to operate on.
1339         // Here the assumption is that each buffer is 4-bytes aligned.
1340         // Otherwise, performance may suffer significantly.
1341         TFLITE_DCHECK_EQ(  // NOLINT
1342             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1343         const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
1344         const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
1345         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
1346         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1347         col += (kInt8ValuesPerNeonVector >> 1);
1348       }
1349 
1350       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
1351 
1352       // Postamble loop.
1353       for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
1354         dotprod += row_ptr[col] * aligned_vec[col];
1355       }  // for col
1356       dotprod -= row_sums_ptr[row] * batch_input_offset;
1357       *result += dotprod * scale;
1358       ++result;
1359     }  // for row
1360   }    // for batch
1361 
1362   if (row_sums == nullptr) {
1363     free(row_sums_ptr);
1364   }
1365   if (unaligned) {
1366     free(aligned_row_free);
1367   }
1368   free(aligned_vec_free);
1369 }
1370 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset,int32_t * scratch,int32_t * row_sums,bool * compute_row_sums,CpuBackendContext * context)1371 void NeonMatrixBatchVectorMultiplyAccumulate(
1372     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
1373     const int8_t* __restrict__ vectors, const float* scaling_factors,
1374     int n_batch, float* __restrict__ result, const float* per_channel_scale,
1375     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
1376     bool* compute_row_sums, CpuBackendContext* context) {
1377 #ifdef TFLITE_WITH_RUY_GEMV
1378   const bool use_cpu_backend_gemm = true;
1379 #else
1380   const bool use_cpu_backend_gemm = UseCpuBackendGemm(m_rows, m_cols, n_batch);
1381 #endif
1382   if (input_offset == nullptr) {
1383     if (use_cpu_backend_gemm && context) {
1384       NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
1385                                               scaling_factors, n_batch, scratch,
1386                                               result, context);
1387       return;
1388     }
1389     NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
1390                                             scaling_factors, n_batch, result);
1391     return;
1392   }
1393 
1394   if (compute_row_sums == nullptr || *compute_row_sums) {
1395     NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
1396     if (compute_row_sums) {
1397       *compute_row_sums = false;
1398     }
1399   }
1400 
1401   if (use_cpu_backend_gemm) {
1402     if (context != nullptr && m_rows % 4 == 0) {
1403       const int32_t* bias = static_cast<const int32_t*>(nullptr);
1404       NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
1405                          scratch, context);
1406 
1407       // Multiply by float scaling factors and write to result
1408       const int total_size = n_batch * m_rows;
1409       int i = 0;
1410       int32_t* scratch_ptr = scratch;
1411       for (; i <= total_size - 8; i += 8, result += 8) {
1412         const float batch_scaling_factor0 = scaling_factors[i / m_rows];
1413         const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
1414         const int batch_input_offset0 = -input_offset[i / m_rows];
1415         const int batch_input_offset1 = -input_offset[(i + 4) / m_rows];
1416         float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
1417         float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
1418         if (per_channel_scale) {
1419           const float32x4_t per_channel_scale0 =
1420               vld1q_f32(&per_channel_scale[i % m_rows]);
1421           const float32x4_t per_channel_scale1 =
1422               vld1q_f32(&per_channel_scale[(i + 4) % m_rows]);
1423           scaling_factor0 = vmulq_f32(scaling_factor0, per_channel_scale0);
1424           scaling_factor1 = vmulq_f32(scaling_factor1, per_channel_scale1);
1425         }
1426         const int32x4_t input_offset0 = vdupq_n_s32(batch_input_offset0);
1427         const int32x4_t input_offset1 = vdupq_n_s32(batch_input_offset1);
1428         const int32x4_t row_sum0 = vld1q_s32(row_sums + (i % m_rows));
1429         const int32x4_t row_sum1 = vld1q_s32(row_sums + ((i + 4) % m_rows));
1430         const int32x4_t scratch_val0 = vld1q_s32(scratch_ptr + i);
1431         const int32x4_t scratch_val1 = vld1q_s32(scratch_ptr + i + 4);
1432         const int32x4_t dotprod0 =
1433             vmlaq_s32(scratch_val0, row_sum0, input_offset0);
1434         const int32x4_t dotprod1 =
1435             vmlaq_s32(scratch_val1, row_sum1, input_offset1);
1436         const float32x4_t float_val0 = vcvtq_f32_s32(dotprod0);
1437         const float32x4_t float_val1 = vcvtq_f32_s32(dotprod1);
1438         const float32x4_t result0 =
1439             vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
1440         const float32x4_t result1 =
1441             vmlaq_f32(vld1q_f32(result + 4), float_val1, scaling_factor1);
1442         vst1q_f32(result, result0);
1443         vst1q_f32(result + 4, result1);
1444       }
1445 
1446       scratch_ptr += i;
1447       for (; TFLITE_UNLIKELY(i < total_size); i++) {
1448         float batch_scaling_factor = scaling_factors[i / m_rows];
1449         if (per_channel_scale) {
1450           batch_scaling_factor *= per_channel_scale[i % m_rows];
1451         }
1452         const int32_t zero_point = input_offset[i / m_rows];
1453         int32_t dotprod = *(scratch_ptr++);
1454         dotprod -= row_sums[i % m_rows] * zero_point;
1455         *result += dotprod * batch_scaling_factor;
1456         ++result;
1457       }
1458       return;
1459     }
1460   }
1461 
1462   NeonMatrixBatchVectorMultiplyAccumulateImpl(
1463       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
1464       per_channel_scale, input_offset, row_sums);
1465 }
1466 
MulAdd(int32x4_t acc,int32x4_t lhs,int32x4_t rhs)1467 inline int64x2x2_t MulAdd(int32x4_t acc, int32x4_t lhs, int32x4_t rhs) {
1468   int64x2x2_t result;
1469   const int64x2_t lhs_low = vmovl_s32(vget_low_s32(lhs));
1470   const int64x2_t lhs_high = vmovl_s32(vget_high_s32(lhs));
1471   const int64_t lhs_0 = vgetq_lane_s64(lhs_low, 0);
1472   const int64_t lhs_1 = vgetq_lane_s64(lhs_low, 1);
1473   const int64_t lhs_2 = vgetq_lane_s64(lhs_high, 0);
1474   const int64_t lhs_3 = vgetq_lane_s64(lhs_high, 1);
1475 
1476   const int64x2_t rhs_low = vmovl_s32(vget_low_s32(rhs));
1477   const int64x2_t rhs_high = vmovl_s32(vget_high_s32(rhs));
1478   const int64_t rhs_0 = vgetq_lane_s64(rhs_low, 0);
1479   const int64_t rhs_1 = vgetq_lane_s64(rhs_low, 1);
1480   const int64_t rhs_2 = vgetq_lane_s64(rhs_high, 0);
1481   const int64_t rhs_3 = vgetq_lane_s64(rhs_high, 1);
1482 
1483   const int64x2_t mul_0 = {lhs_0 * rhs_0, lhs_1 * rhs_1};
1484   const int64x2_t mul_1 = {lhs_2 * rhs_2, lhs_3 * rhs_3};
1485 
1486   result.val[0] = vaddq_s64(vmovl_s32(vget_low_s32(acc)), mul_0);
1487   result.val[1] = vaddq_s64(vmovl_s32(vget_high_s32(acc)), mul_1);
1488   return result;
1489 }
1490 
NeonApplyLayerNorm(const int16_t * input,const int16_t * layer_norm_weights,const int32_t * bias,int32_t layer_norm_scale_a,int32_t layer_norm_scale_b,int32_t variance_limit,int n_batch,int n_input,int16_t * output)1491 void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
1492                         const int32_t* bias, int32_t layer_norm_scale_a,
1493                         int32_t layer_norm_scale_b, int32_t variance_limit,
1494                         int n_batch, int n_input, int16_t* output) {
1495   const int32 int16_max = std::numeric_limits<int16>::max();
1496   const int32 int16_min = std::numeric_limits<int16>::min();
1497   const int32 temp = 1048576 / n_input;
1498 
1499   for (int i = 0; i < n_batch; ++i) {
1500     int64_t sum = 0;
1501     int64_t sum_sq = 0;
1502 
1503     int j = 0;
1504     for (; j <= n_input - 8; j += 8) {
1505       const int32 index = i * n_input + j;
1506       const int16x8_t val_s16 = vld1q_s16(input + index);
1507       const int32x4_t val_s32_0 = vmovl_s16(vget_low_s16(val_s16));
1508       const int32x4_t val_s32_1 = vmovl_s16(vget_high_s16(val_s16));
1509 
1510       sum += static_cast<int64_t>(AccumulateNeonLane(val_s32_0));
1511       sum += static_cast<int64_t>(AccumulateNeonLane(val_s32_1));
1512 
1513       sum_sq += static_cast<int64_t>(
1514           AccumulateNeonLane(vmulq_s32(val_s32_0, val_s32_0)));
1515       sum_sq += static_cast<int64_t>(
1516           AccumulateNeonLane(vmulq_s32(val_s32_1, val_s32_1)));
1517     }
1518     for (; TFLITE_UNLIKELY(j < n_input); ++j) {
1519       const int32 index = i * n_input + j;
1520       int32 val = static_cast<int32_t>(input[index]);
1521       sum += val;
1522       sum_sq += val * val;
1523     }
1524 
1525     int32_t mean =
1526         static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
1527     // TODO(jianlijianli): Avoids overflow but only works for POT n_input.
1528     int64_t variance =
1529         sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
1530     int32_t variance2 = static_cast<int32>(variance / 1048576);
1531     if (variance2 < 1) {
1532       variance2 = variance_limit;
1533     }
1534     int32_t stddev_inverse_a;
1535     int stddev_inverse_b;
1536     GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
1537                                      &stddev_inverse_a, &stddev_inverse_b);
1538 
1539     j = 0;
1540     const int32x4_t mean_dup = vdupq_n_s32(mean);
1541     for (; j <= n_input - 16; j += 16) {
1542       // Load 16 items at once.
1543       const int32 index = i * n_input + j;
1544       const int16x8_t val_s16_0 = vld1q_s16(input + index);
1545       const int16x8_t val_s16_1 = vld1q_s16(input + index + 8);
1546 
1547       int32x4x4_t shifted;
1548       shifted.val[0] = vsubq_s32(
1549           vshlq_n_s32(vmovl_s16(vget_low_s16(val_s16_0)), 10), mean_dup);
1550       shifted.val[1] = vsubq_s32(
1551           vshlq_n_s32(vmovl_s16(vget_high_s16(val_s16_0)), 10), mean_dup);
1552       shifted.val[2] = vsubq_s32(
1553           vshlq_n_s32(vmovl_s16(vget_low_s16(val_s16_1)), 10), mean_dup);
1554       shifted.val[3] = vsubq_s32(
1555           vshlq_n_s32(vmovl_s16(vget_high_s16(val_s16_1)), 10), mean_dup);
1556 
1557       int32x4x4_t rescaled = MultiplyByQuantizedMultiplier4Rows(
1558           shifted, stddev_inverse_a, stddev_inverse_b);
1559 
1560       const int32x4_t bias_0 = vld1q_s32(bias + j);
1561       const int32x4_t bias_1 = vld1q_s32(bias + j + 4);
1562       const int32x4_t bias_2 = vld1q_s32(bias + j + 8);
1563       const int32x4_t bias_3 = vld1q_s32(bias + j + 12);
1564 
1565       const int16x8_t layer_norm_weights_s16_0 =
1566           vld1q_s16(layer_norm_weights + j);
1567       const int16x8_t layer_norm_weights_s16_1 =
1568           vld1q_s16(layer_norm_weights + j + 8);
1569       const int32x4_t layer_norm_weights_s32_0 =
1570           vmovl_s16(vget_low_s16(layer_norm_weights_s16_0));
1571       const int32x4_t layer_norm_weights_s32_1 =
1572           vmovl_s16(vget_high_s16(layer_norm_weights_s16_0));
1573       const int32x4_t layer_norm_weights_s32_2 =
1574           vmovl_s16(vget_low_s16(layer_norm_weights_s16_1));
1575       const int32x4_t layer_norm_weights_s32_3 =
1576           vmovl_s16(vget_high_s16(layer_norm_weights_s16_1));
1577 
1578       int64x2x2_t val3_0 =
1579           MulAdd(bias_0, rescaled.val[0], layer_norm_weights_s32_0);
1580       int64x2x2_t val3_1 =
1581           MulAdd(bias_1, rescaled.val[1], layer_norm_weights_s32_1);
1582       int64x2x2_t val3_2 =
1583           MulAdd(bias_2, rescaled.val[2], layer_norm_weights_s32_2);
1584       int64x2x2_t val3_3 =
1585           MulAdd(bias_3, rescaled.val[3], layer_norm_weights_s32_3);
1586 
1587       int32x4x4_t val4;
1588       val4.val[0] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_0.val[0], 10)),
1589                                  vmovn_s64(vrshrq_n_s64(val3_0.val[1], 10)));
1590       val4.val[1] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_1.val[0], 10)),
1591                                  vmovn_s64(vrshrq_n_s64(val3_1.val[1], 10)));
1592       val4.val[2] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_2.val[0], 10)),
1593                                  vmovn_s64(vrshrq_n_s64(val3_2.val[1], 10)));
1594       val4.val[3] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_3.val[0], 10)),
1595                                  vmovn_s64(vrshrq_n_s64(val3_3.val[1], 10)));
1596 
1597       int32x4x4_t val5_s32 = MultiplyByQuantizedMultiplier4Rows(
1598           val4, layer_norm_scale_a, layer_norm_scale_b + 12);
1599       vst1_s16(output + index, vqmovn_s32(val5_s32.val[0]));
1600       vst1_s16(output + index + 4, vqmovn_s32(val5_s32.val[1]));
1601       vst1_s16(output + index + 8, vqmovn_s32(val5_s32.val[2]));
1602       vst1_s16(output + index + 12, vqmovn_s32(val5_s32.val[3]));
1603     }
1604     for (; TFLITE_UNLIKELY(j < n_input); ++j) {
1605       const int32 index = i * n_input + j;
1606       int32 val = static_cast<int32_t>(input[index]);
1607       int32 shifted = 1024 * val - mean;
1608       int32 rescaled = MultiplyByQuantizedMultiplier(shifted, stddev_inverse_a,
1609                                                      stddev_inverse_b);
1610       // TODO(jianlijianli): Saturate this.
1611       int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
1612       int32 val4 =
1613           static_cast<int32>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
1614       int32 val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
1615                                                  layer_norm_scale_b + 12);
1616       val5 = std::min(std::max(int16_min, val5), int16_max);
1617       output[index] = static_cast<int16_t>(val5);
1618     }
1619   }
1620 }
1621 
NeonApplySigmoid(const int16_t * input,int32_t n_batch,int32_t n_input,int16_t * output)1622 void NeonApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
1623                       int16_t* output) {
1624   for (int batch = 0; batch < n_batch; ++batch) {
1625     int i = 0;
1626 #ifdef GEMMLOWP_NEON
1627     // F0 uses 0 integer bits, range [-1, 1].
1628     // This is the return type of math functions such as tanh, logistic,
1629     // whose range is in [-1, 1].
1630     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
1631     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
1632     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
1633 
1634     for (; i <= n_input - 32; i += 32) {
1635       const int index = batch * n_input + i;
1636       F3 input0 = F3::FromRaw(vld1q_s16(input + index));
1637       F3 input1 = F3::FromRaw(vld1q_s16(input + index + 8));
1638       F3 input2 = F3::FromRaw(vld1q_s16(input + index + 16));
1639       F3 input3 = F3::FromRaw(vld1q_s16(input + index + 24));
1640       F0 output0 = gemmlowp::logistic(input0);
1641       F0 output1 = gemmlowp::logistic(input1);
1642       F0 output2 = gemmlowp::logistic(input2);
1643       F0 output3 = gemmlowp::logistic(input3);
1644       vst1q_s16(output + index, output0.raw());
1645       vst1q_s16(output + index + 8, output1.raw());
1646       vst1q_s16(output + index + 16, output2.raw());
1647       vst1q_s16(output + index + 24, output3.raw());
1648     }
1649 #endif  // GEMMLOWP_NEON
1650     using F0_Scalar = gemmlowp::FixedPoint<int16_t, 0>;
1651     using F3_Scalar = gemmlowp::FixedPoint<int16_t, 3>;
1652     for (; i < n_input; ++i) {
1653       const int index = batch * n_input + i;
1654       F3_Scalar input_f3 = F3_Scalar::FromRaw(input[index]);
1655       F0_Scalar output_f0 = gemmlowp::logistic(input_f3);
1656       output[index] = output_f0.raw();
1657     }
1658   }
1659 }
1660 
1661 template <int IntegerBits>
NeonApplyTanhImpl(const int16_t * input,int32_t n_batch,int32_t n_input,int16_t * output)1662 void NeonApplyTanhImpl(const int16_t* input, int32_t n_batch, int32_t n_input,
1663                        int16_t* output) {
1664   for (int batch = 0; batch < n_batch; ++batch) {
1665     int i = 0;
1666 #ifdef GEMMLOWP_NEON
1667     // F0 uses 0 integer bits, range [-1, 1].
1668     // This is the return type of math functions such as tanh, logistic,
1669     // whose range is in [-1, 1].
1670     using F_In = gemmlowp::FixedPoint<int16x8_t, IntegerBits>;
1671     using F_Out = gemmlowp::FixedPoint<int16x8_t, 0>;
1672 
1673     for (; i <= n_input - 32; i += 32) {
1674       const int index = batch * n_input + i;
1675       F_In input0 = F_In::FromRaw(vld1q_s16(input + index));
1676       F_In input1 = F_In::FromRaw(vld1q_s16(input + index + 8));
1677       F_In input2 = F_In::FromRaw(vld1q_s16(input + index + 16));
1678       F_In input3 = F_In::FromRaw(vld1q_s16(input + index + 24));
1679       F_Out output0 = gemmlowp::tanh(input0);
1680       F_Out output1 = gemmlowp::tanh(input1);
1681       F_Out output2 = gemmlowp::tanh(input2);
1682       F_Out output3 = gemmlowp::tanh(input3);
1683       vst1q_s16(output + index, output0.raw());
1684       vst1q_s16(output + index + 8, output1.raw());
1685       vst1q_s16(output + index + 16, output2.raw());
1686       vst1q_s16(output + index + 24, output3.raw());
1687     }
1688 #endif  // GEMMLOWP_NEON
1689     using F_In_Scalar = gemmlowp::FixedPoint<int16_t, IntegerBits>;
1690     using F_Out_Scalar = gemmlowp::FixedPoint<int16_t, 0>;
1691     for (; i < n_input; ++i) {
1692       const int index = batch * n_input + i;
1693       F_In_Scalar input_in = F_In_Scalar::FromRaw(input[index]);
1694       F_Out_Scalar output_out = gemmlowp::tanh(input_in);
1695       output[index] = output_out.raw();
1696     }
1697   }
1698 }
1699 
NeonApplyTanh(int32_t integer_bits,const int16_t * input,int32_t n_batch,int32_t n_input,int16_t * output)1700 void NeonApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
1701                    int32_t n_input, int16_t* output) {
1702   assert(integer_bits <= 6);
1703 #define DISPATCH_TANH(i)                                   \
1704   case i:                                                  \
1705     NeonApplyTanhImpl<i>(input, n_batch, n_input, output); \
1706     break;
1707   switch (integer_bits) {
1708     DISPATCH_TANH(0);
1709     DISPATCH_TANH(1);
1710     DISPATCH_TANH(2);
1711     DISPATCH_TANH(3);
1712     DISPATCH_TANH(4);
1713     DISPATCH_TANH(5);
1714     DISPATCH_TANH(6);
1715     default:
1716       return;
1717   }
1718 #undef DISPATCH_TANH
1719 }
1720 
NeonCwiseMul(const int16_t * input_1,const int16_t * input_2,int n_batch,int n_input,int shift,int16_t * output)1721 void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
1722                   int n_input, int shift, int16_t* output) {
1723   for (int batch = 0; batch < n_batch; ++batch) {
1724     int i = 0;
1725     for (; i <= n_input - 8; i += 8) {
1726       const int index = batch * n_input + i;
1727       const int16x8_t a = vld1q_s16(input_1 + index);
1728       const int16x8_t b = vld1q_s16(input_2 + index);
1729       const int32x4_t a_s32_0 = vmovl_s16(vget_low_s16(a));
1730       const int32x4_t a_s32_1 = vmovl_s16(vget_high_s16(a));
1731       const int32x4_t b_s32_0 = vmovl_s16(vget_low_s16(b));
1732       const int32x4_t b_s32_1 = vmovl_s16(vget_high_s16(b));
1733 
1734       int32x4_t x_0 = vmulq_s32(a_s32_0, b_s32_0);
1735       int32x4_t x_1 = vmulq_s32(a_s32_1, b_s32_1);
1736       x_0 = gemmlowp::RoundingDivideByPOT(x_0, shift);
1737       x_1 = gemmlowp::RoundingDivideByPOT(x_1, shift);
1738 
1739       const int16x8_t result = vcombine_s16(vmovn_s32(x_0), vmovn_s32(x_1));
1740       vst1q_s16(output + index, result);
1741     }
1742     for (; TFLITE_UNLIKELY(i < n_input); ++i) {
1743       const int index = batch * n_input + i;
1744       const int16_t a = input_1[index];
1745       const int16_t b = input_2[index];
1746       int64_t x = a * b;
1747       if (x > std::numeric_limits<std::int32_t>::max()) {
1748         x = std::numeric_limits<std::int32_t>::max();
1749       }
1750       const int32_t value = static_cast<int32_t>(x);
1751       output[index] =
1752           static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
1753     }
1754   }
1755 }
1756 
NeonCwiseMul(const int16_t * input_1,const int16_t * input_2,int32_t multiplier,int shift,int n_batch,int n_input,int32_t output_zp,int8_t * output)1757 void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2,
1758                   int32_t multiplier, int shift, int n_batch, int n_input,
1759                   int32_t output_zp, int8_t* output) {
1760   const int32_t output_min = std::numeric_limits<int8_t>::min();
1761   const int32_t output_max = std::numeric_limits<int8_t>::max();
1762 
1763   const int32x4_t output_zp_dup = vdupq_n_s32(-output_zp);
1764   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
1765   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
1766 
1767   for (int batch = 0; batch < n_batch; ++batch) {
1768     int i = 0;
1769     for (; i <= n_input - 8; i += 8) {
1770       const int index = batch * n_input + i;
1771       const int16x8_t a = vld1q_s16(input_1 + index);
1772       const int16x8_t b = vld1q_s16(input_2 + index);
1773       const int32x4_t a_s32_0 = vmovl_s16(vget_low_s16(a));
1774       const int32x4_t a_s32_1 = vmovl_s16(vget_high_s16(a));
1775       const int32x4_t b_s32_0 = vmovl_s16(vget_low_s16(b));
1776       const int32x4_t b_s32_1 = vmovl_s16(vget_high_s16(b));
1777 
1778       int32x4x2_t temp_val;
1779       temp_val.val[0] = vmulq_s32(a_s32_0, b_s32_0);
1780       temp_val.val[1] = vmulq_s32(a_s32_1, b_s32_1);
1781       temp_val =
1782           MultiplyByQuantizedMultiplier2Rows(temp_val, multiplier, shift);
1783 
1784       temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
1785       temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
1786       temp_val.val[0] =
1787           vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
1788       temp_val.val[1] =
1789           vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
1790 
1791       const int16x8_t result =
1792           vcombine_s16(vmovn_s32(temp_val.val[0]), vmovn_s32(temp_val.val[1]));
1793       vst1_s8(output + index, vmovn_s16(result));
1794     }
1795     for (; TFLITE_UNLIKELY(i < n_input); ++i) {
1796       const int index = batch * n_input + i;
1797       const int16_t a = input_1[index];
1798       const int16_t b = input_2[index];
1799       int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
1800       value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
1801       value -= output_zp;
1802       value = std::min(std::max(-128, value), 127);
1803 
1804       output[index] = static_cast<int8>(value);
1805     }
1806   }
1807 }
1808 
NeonCwiseAdd(const int16_t * input_1,const int16_t * input_2,int n_batch,int n_input,int16_t * output)1809 void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
1810                   int n_input, int16_t* output) {
1811   const int32 int16_max = std::numeric_limits<int16>::max();
1812   const int32 int16_min = std::numeric_limits<int16>::min();
1813   for (int batch = 0; batch < n_batch; ++batch) {
1814     int i = 0;
1815     for (; i <= n_input - 8; i += 8) {
1816       const int index = batch * n_input + i;
1817       const int16x8_t a = vld1q_s16(input_1 + index);
1818       const int16x8_t b = vld1q_s16(input_2 + index);
1819       const int32x4_t a_s32_0 = vmovl_s16(vget_low_s16(a));
1820       const int32x4_t a_s32_1 = vmovl_s16(vget_high_s16(a));
1821       const int32x4_t b_s32_0 = vmovl_s16(vget_low_s16(b));
1822       const int32x4_t b_s32_1 = vmovl_s16(vget_high_s16(b));
1823 
1824       const int32x4_t sum_0 = vaddq_s32(a_s32_0, b_s32_0);
1825       const int32x4_t sum_1 = vaddq_s32(a_s32_1, b_s32_1);
1826       vst1_s16(output + index, vqmovn_s32(sum_0));
1827       vst1_s16(output + index + 4, vqmovn_s32(sum_1));
1828     }
1829     for (; TFLITE_UNLIKELY(i < n_input); ++i) {
1830       const int index = batch * n_input + i;
1831       int32_t sum = input_1[index] + input_2[index];
1832       const int32 sum_clamped = std::min(int16_max, std::max(int16_min, sum));
1833       output[index] = static_cast<int16_t>(sum_clamped);
1834     }
1835   }
1836 }
1837 
NeonCwiseClipping(float * vector,const int v_size,const float clipping_value)1838 void NeonCwiseClipping(float* vector, const int v_size,
1839                        const float clipping_value) {
1840   const float32x4_t clipping_value_f32x4 = vmovq_n_f32(clipping_value);
1841   const float32x4_t neg_clipping_value_f32x4 = vmovq_n_f32(-clipping_value);
1842 
1843   int i = 0;
1844   for (; i <= v_size - kFloatValuesPerNeonVector;
1845        i += kFloatValuesPerNeonVector) {
1846     // Load from memory to vector.
1847     float32x4_t v_f32x4 = vld1q_f32(vector + i);
1848     // Clip between clipping_value and -clipping_value.
1849     v_f32x4 = vminq_f32(clipping_value_f32x4, v_f32x4);
1850     v_f32x4 = vmaxq_f32(neg_clipping_value_f32x4, v_f32x4);
1851     // Save to output.
1852     vst1q_f32(vector + i, v_f32x4);
1853   }
1854   for (; TFLITE_UNLIKELY(i < v_size); i++) {
1855     vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
1856   }
1857 }
1858 
NeonCwiseClipping(int16_t * vector,const int v_size,const int16_t clipping_value)1859 void NeonCwiseClipping(int16_t* vector, const int v_size,
1860                        const int16_t clipping_value) {
1861   const int16x8_t max_dup = vdupq_n_s16(clipping_value);
1862   const int16x8_t min_dup = vdupq_n_s16(-clipping_value);
1863 
1864   int i = 0;
1865   for (; i <= v_size - kInt16ValuesPerNeonVector * 2;
1866        i += kInt16ValuesPerNeonVector * 2) {
1867     int16x8_t val_0 = vld1q_s16(vector + i);
1868     int16x8_t val_1 = vld1q_s16(vector + i + kInt16ValuesPerNeonVector);
1869     val_0 = vminq_s16(val_0, max_dup);
1870     val_1 = vminq_s16(val_1, max_dup);
1871     val_0 = vmaxq_s16(val_0, min_dup);
1872     val_1 = vmaxq_s16(val_1, min_dup);
1873     vst1q_s16(vector + i, val_0);
1874     vst1q_s16(vector + i + kInt16ValuesPerNeonVector, val_1);
1875   }
1876   for (; TFLITE_UNLIKELY(i < v_size); i++) {
1877     vector[i] = std::max(std::min(clipping_value, vector[i]),
1878                          static_cast<int16_t>(-clipping_value));
1879   }
1880 }
1881 
NeonCwiseClipping(int8_t * vector,const int v_size,const int8_t clipping_value)1882 void NeonCwiseClipping(int8_t* vector, const int v_size,
1883                        const int8_t clipping_value) {
1884   const int8x16_t max_dup = vdupq_n_s8(clipping_value);
1885   const int8x16_t min_dup = vdupq_n_s8(-clipping_value);
1886 
1887   int i = 0;
1888   for (; i < v_size - kInt8ValuesPerNeonVector * 2;
1889        i += kInt8ValuesPerNeonVector * 2) {
1890     int8x16_t val_0 = vld1q_s8(vector + i);
1891     int8x16_t val_1 = vld1q_s8(vector + i + kInt8ValuesPerNeonVector);
1892     val_0 = vminq_s8(val_0, max_dup);
1893     val_1 = vminq_s8(val_1, max_dup);
1894     val_0 = vmaxq_s8(val_0, min_dup);
1895     val_1 = vmaxq_s8(val_1, min_dup);
1896     vst1q_s8(vector + i, val_0);
1897     vst1q_s8(vector + i + kInt8ValuesPerNeonVector, val_1);
1898   }
1899   for (; TFLITE_UNLIKELY(i < v_size); i++) {
1900     vector[i] = std::max(std::min(clipping_value, vector[i]),
1901                          static_cast<int8_t>(-clipping_value));
1902   }
1903 }
1904 
NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(const float * __restrict__ matrix,const int32_t * __restrict__ segments,const int32_t * __restrict__ indices,int m_rows,int m_cols,const float * __restrict__ vector,int n_batch,float * __restrict__ result)1905 void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
1906     const float* __restrict__ matrix, const int32_t* __restrict__ segments,
1907     const int32_t* __restrict__ indices, int m_rows, int m_cols,
1908     const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
1909   constexpr int kBlockSize = kFloatValuesPerNeonVector;
1910   TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
1911 
1912   for (int batch = 0; batch < n_batch; batch++) {
1913     const float* matrix_ptr = matrix;
1914     for (int row = 0; row < m_rows; row++) {
1915       float32x4_t acc_32x4 = vmovq_n_f32(0.0);
1916       const float* vector_in_batch = vector + batch * m_cols;
1917 
1918       for (int i = segments[row]; i < segments[row + 1]; i++) {
1919         const int block_start_index = indices[i] * kBlockSize;
1920         const float* vector_block_in_batch_ptr =
1921             vector_in_batch + block_start_index;
1922 
1923         // Load 4 float values from the vector and matrix row.
1924         float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr);
1925         float32x4_t matrix_f32x4 = vld1q_f32(matrix_ptr);
1926         // Multiply the vector and matrix row and add to accumulator.
1927         acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
1928         matrix_ptr += kBlockSize;
1929       }
1930       result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
1931     }
1932   }
1933 }
1934 
NeonSparseMatrixBatchVectorMultiplyAccumulate(const float * __restrict__ matrix,const uint8_t * __restrict__ ledger,int m_rows,int m_cols,const float * __restrict__ vector,int n_batch,float * __restrict__ result)1935 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
1936     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
1937     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
1938     float* __restrict__ result) {
1939   constexpr int kNeonVectorsPerBlock = 4;
1940   constexpr int kBlockSize = kNeonVectorsPerBlock * kFloatValuesPerNeonVector;
1941   TFLITE_DCHECK_EQ(  // NOLINT
1942       m_cols % kBlockSize, 0);
1943 
1944   for (int batch = 0; batch < n_batch; batch++) {
1945     const float* matrix_ptr = matrix;
1946     const uint8_t* ledger_ptr = ledger;
1947     for (int row = 0; row < m_rows; row++) {
1948       int num_nonzero_blocks = *ledger_ptr++;
1949       if (num_nonzero_blocks > 0) {
1950         float32x4_t acc_32x4 = vmovq_n_f32(0.0);
1951         const float* vector_in_batch = vector + batch * m_cols;
1952 
1953         for (int i = 0; i < num_nonzero_blocks; i++) {
1954           const int block_start_index = *ledger_ptr++ * kBlockSize;
1955           const float* vector_block_in_batch_ptr =
1956               vector_in_batch + block_start_index;
1957 
1958           for (int c = 0; c < kNeonVectorsPerBlock; c++) {
1959             // Load 4 float values from the vector and matrix row.
1960             float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr +
1961                                                  c * kFloatValuesPerNeonVector);
1962             float32x4_t matrix_f32x4 =
1963                 vld1q_f32(matrix_ptr + c * kFloatValuesPerNeonVector);
1964             // Multiply the vector and matrix row and add to accumulator.
1965             acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
1966           }
1967           matrix_ptr += kBlockSize;
1968         }
1969         result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
1970       }
1971     }
1972   }
1973 }
1974 
NeonSparseMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const uint8_t * ledger,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)1975 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
1976     const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
1977     const int m_cols, const int8_t* __restrict__ vectors,
1978     const float* scaling_factors, int n_batch, float* __restrict__ result) {
1979 #ifdef __aarch64__
1980   if (HasSdotInstruction() && m_cols % 16 == 0) {
1981     DotprodSparseMatrixBatchVectorMultiplyAccumulate(
1982         matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
1983         result);
1984     return;
1985   }
1986 #endif  // __aarch64__
1987 
1988   constexpr int kBlockSize = kInt8ValuesPerNeonVector;
1989   TFLITE_DCHECK_EQ(  // NOLINT
1990       m_cols % kBlockSize, 0);
1991   void* aligned_vec_free = nullptr;
1992   int8_t* aligned_vec =
1993       (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1994                              &aligned_vec_free);
1995 
1996   for (int batch = 0; batch < n_batch; ++batch) {
1997     const float batch_scaling_factor = scaling_factors[batch];
1998     // Copy the vector data to an aligned vector.
1999     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
2000 
2001     const uint8_t* ledger_ptr = ledger;
2002     const int8_t* row_ptr = matrix;
2003     for (int row = 0; row < m_rows; ++row) {
2004       // Initialize the dot product sum for the row to 0.
2005       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
2006       int num_nonzero_blocks = *ledger_ptr++;
2007       if (num_nonzero_blocks > 0) {
2008         // Prefetch the row to cache.
2009         __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
2010                            3 /* temporal locality */);
2011         for (int i = 0; i < num_nonzero_blocks; i++) {
2012           const int col_index = *ledger_ptr++ * kBlockSize;
2013           // Load 16 8-bit values from the row and vector, each, to operate on.
2014           // Here the assumption is that each buffer is 4-byte aligned.
2015           // Otherwise, performance may suffer significantly.
2016           TFLITE_DCHECK_EQ(  // NOLINT
2017               (uintptr_t)(&row_ptr) & (kNeonVectorAlignment - 1), 0);
2018           const int8x16_t s1_8x16 =
2019               vld1q_s8((const int8_t*)(aligned_vec + col_index));
2020           const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr));
2021           // Multiply the low bits (i.e. the lower 8 8bit numbers in the
2022           // registers).
2023           int16x8_t prod_16x8 =
2024               vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
2025           // Multiply the high bits (i.e. the lower 8 8bit numbers in the
2026           // registers), and accumulate with the result of the low bits product.
2027           // The assumption here is that overflow will not happen as we quantize
2028           // our values to be in the range [-127, 127]. As such the sum of the 2
2029           // products is always strictly smaller than 15-bits (32767 in absolute
2030           // value).
2031           prod_16x8 =
2032               vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
2033 
2034           dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
2035           row_ptr += kBlockSize;
2036         }
2037         // Add the 4 intermediate sum values to get the final dot-prod value for
2038         // this row.
2039         int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
2040         result[batch * m_rows + row] += dotprod * batch_scaling_factor;
2041       }
2042     }  // for row
2043   }    // for batch
2044   free(aligned_vec_free);
2045 }
2046 
NeonSub1Vector(const float * vector,int v_size,float * result)2047 void NeonSub1Vector(const float* vector, int v_size, float* result) {
2048   // If v_size is not divisible by the vector size, then we need to process the
2049   // final few elements sequentially. postamble_start shows the start index
2050   // where this should happen.
2051   const int postamble_start =
2052       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2053 
2054   float32x4_t one_f32x4 = vmovq_n_f32(1.0);
2055   int v = 0;
2056   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2057     // Load 4 float values from the current pointers of the input column and
2058     // subtract from 1.
2059     float32x4_t v_f32x4 = vld1q_f32(vector + v);
2060     float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
2061     // Save to output.
2062     vst1q_f32(result + v, result_f32x4);
2063   }
2064   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2065     result[v] = 1.0f - vector[v];
2066   }
2067 }
2068 
NeonSub1Vector(const int16_t * vector,int v_size,int16_t * result)2069 void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
2070   int postamble_start = RoundDownVectors<kInt16ValuesPerNeonVector>(v_size);
2071   static const int16_t kOne = 32767;
2072   // Use xor to replace substract from 1 << 15 - 1.
2073   // Local benchmark shows it's slightly faster than pure substract.
2074   const int16x8_t one_dup = vdupq_n_s16(kOne);
2075   int i = 0;
2076   for (; i < postamble_start; i += kInt16ValuesPerNeonVector) {
2077     const int16x8_t input = vld1q_s16(vector + i);
2078     const int16x8_t sub1_result = veorq_s16(one_dup, input);
2079     vst1q_s16(result + i, sub1_result);
2080   }
2081   for (; TFLITE_UNLIKELY(i < v_size); i++) {
2082     result[i] = kOne ^ vector[i];
2083   }
2084 }
2085 
2086 namespace {
2087 
2088 #ifdef __aarch64__
IsAllZero(const int8x16_t v_s8x16)2089 inline bool IsAllZero(const int8x16_t v_s8x16) {
2090   const uint32_t u32 = vmaxvq_u32(vreinterpretq_u32_s8(v_s8x16));
2091   return !u32;
2092 }
2093 
IsAllZero(const float32x4_t v_f32x4)2094 inline bool IsAllZero(const float32x4_t v_f32x4) {
2095   const uint32x4_t cmp_result = vceqzq_f32(v_f32x4);
2096   const uint32_t u32 = vminvq_u32(cmp_result);
2097   return u32;
2098 }
2099 #else
2100 inline bool IsAllZero(const uint32x4_t u32x4) {
2101   const uint32x2_t u32x2 = vqadd_u32(vget_high_u32(u32x4), vget_low_u32(u32x4));
2102   const uint64x1_t u64 = vreinterpret_u64_u32(u32x2);
2103   return !vget_lane_u64(u64, 0);
2104 }
2105 
2106 #ifndef __SSE__
2107 // With Intel NEON-2-SSE translator library, this is a redefinition..
2108 inline bool IsAllZero(const int8x16_t v) {
2109   return IsAllZero(vreinterpretq_u32_s8(v));
2110 }
2111 #endif
2112 
2113 inline bool IsAllZero(const float32x4_t v_f32x4) {
2114   const float32x4_t zero_f32x4 = vmovq_n_f32(0.0f);
2115   // Compare-absolute greater-than, |v| > |0|, equivalently v != 0
2116   const uint32x4_t cmp_result = vcagtq_f32(v_f32x4, zero_f32x4);
2117   return IsAllZero(cmp_result);
2118 }
2119 #endif
2120 
2121 }  // namespace
2122 
NeonIsZeroVector(const float * vector,int v_size)2123 bool NeonIsZeroVector(const float* vector, int v_size) {
2124   // If v_size is not divisible by the vector size, then we need to process the
2125   // final few elements sequentially. postamble_start shows the start index
2126   // where this should happen.
2127   const int postamble_start =
2128       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2129 
2130   int v = 0;
2131   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2132     const float32x4_t v_f32x4 = vld1q_f32(vector + v);
2133     if (!IsAllZero(v_f32x4)) return false;
2134   }
2135   // Postamble loop
2136   for (; TFLITE_UNLIKELY(v < v_size); ++v) {
2137     if (vector[v] != 0.0) return false;
2138   }
2139   return true;
2140 }
2141 
NeonIsZeroVector(const int8_t * vector,int v_size)2142 bool NeonIsZeroVector(const int8_t* vector, int v_size) {
2143   // If v_size is not divisible by the vector size, then we need to process the
2144   // final few elements sequentially. postamble_start shows the start index
2145   // where this should happen.
2146   const int postamble_start =
2147       RoundDownVectors<kInt8ValuesPerNeonVector>(v_size);
2148 
2149   int v = 0;
2150   for (; v < postamble_start; v += kInt8ValuesPerNeonVector) {
2151     const int8x16_t v_s8x16 = vld1q_s8(vector + v);
2152     if (!IsAllZero(v_s8x16)) return false;
2153   }
2154   // Postamble loop
2155   for (; TFLITE_UNLIKELY(v < v_size); ++v) {
2156     if (vector[v] != 0) return false;
2157   }
2158   return true;
2159 }
2160 
NeonVectorScalarMultiply(const int8_t * vector,const int v_size,const float scale,float * result)2161 void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
2162                               const float scale, float* result) {
2163   // Here the assumption is that each buffer is 4-byte aligned.
2164   TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kNeonVectorAlignment - 1), 0);
2165   // If v_size is not divisible by kInt8ValuesPerNeonVector, we cannot use the
2166   // main vectorized loop, and we need to process sequentially. postamble_start
2167   // shows the start index where this should happen.
2168   const int postamble_start =
2169       RoundDownVectors<kInt8ValuesPerNeonVector>(v_size);
2170 
2171   // Create a vector of 4 floats with the scale value.
2172   const float32x4_t scale_f32x4 = vdupq_n_f32(scale);
2173   int v = 0;
2174   for (; v < postamble_start; v += kInt8ValuesPerNeonVector) {
2175     // Load int8 values, sixteen at a time.
2176     const int8x16_t v_i8x16 = vld1q_s8(vector + v);
2177     // Split it into two components of size eight.
2178     const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16);
2179     const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16);
2180     // Convert both components to int16 first.
2181     const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8);
2182     const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8);
2183     // Split each of them into two components each.
2184     const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8);
2185     const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8);
2186     const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8);
2187     const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8);
2188     // Convert these to int32 and then to float.
2189     float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
2190     float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
2191     float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4));
2192     float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4));
2193     // Vector multiply four floats at a time.
2194     v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
2195     v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
2196     v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4);
2197     v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4);
2198     // Store the results.
2199     vst1q_f32(result + v, v0_f32x4);
2200     vst1q_f32(result + v + 4, v1_f32x4);
2201     vst1q_f32(result + v + 8, v2_f32x4);
2202     vst1q_f32(result + v + 12, v3_f32x4);
2203   }
2204 
2205   if (TFLITE_UNLIKELY(v_size - postamble_start >=
2206                       (kInt8ValuesPerNeonVector >> 1))) {
2207     // Load eight int8 values, if there is at least eight remaining.
2208     const int8x8_t v_i8x8 = vld1_s8(vector + v);
2209     // Convert them to int16 first.
2210     const int16x8_t v_i16x8 = vmovl_s8(v_i8x8);
2211     // Split it into two components.
2212     const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8);
2213     const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8);
2214     // Convert the components two floats.
2215     float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
2216     float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
2217     // Vector multiply four floats at a time.
2218     v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
2219     v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
2220     // Store the results.
2221     vst1q_f32(result + v, v0_f32x4);
2222     vst1q_f32(result + v + 4, v1_f32x4);
2223     v += (kInt8ValuesPerNeonVector >> 1);
2224   }
2225 
2226   // Postamble loop.
2227   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2228     result[v] = scale * vector[v];
2229   }
2230 }
2231 
2232 // TODO(renjieliu): Avoid duplicating the logic.
2233 // Also consider changing the rounding stragey from "ties to away" to
2234 // "ties to even" since vcvtnq_s32_f32 is generally more available.
RoundToNearest(const float32x4_t input)2235 inline int32x4_t RoundToNearest(const float32x4_t input) {
2236 #if __ARM_ARCH >= 8
2237   return vcvtaq_s32_f32(input);
2238 #else
2239   static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
2240   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
2241 
2242   const int32x4_t mask = vreinterpretq_s32_u32(vcltq_f32(input, zero_val_dup));
2243   const float32x4_t casted_mask = vcvtq_f32_s32(mask);
2244   const float32x4_t round = vaddq_f32(casted_mask, point5_val_dup);
2245   return vcvtq_s32_f32(vaddq_f32(input, round));
2246 #endif
2247 }
2248 
2249 // Note: this function caps minimum and maximum at zero, unlike the true
2250 // minmax_element. This is intentional.
NeonMinMax(const float * values,const int size,float * min,float * max)2251 inline void NeonMinMax(const float* values, const int size, float* min,
2252                        float* max) {
2253   const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(size);
2254   float rmin = 0.0f, rmax = 0.0f;
2255   int i = 0;
2256   if (postamble_start) {
2257     float32x4_t min_f32x4 = vld1q_f32(values);
2258     float32x4_t max_f32x4 = min_f32x4;
2259     for (i = kFloatValuesPerNeonVector; i < postamble_start;
2260          i += kFloatValuesPerNeonVector) {
2261       const float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2262       min_f32x4 = vminq_f32(min_f32x4, value0_f32x4);
2263       max_f32x4 = vmaxq_f32(max_f32x4, value0_f32x4);
2264     }
2265 #ifdef __aarch64__
2266     rmin = std::min(rmin, vminvq_f32(min_f32x4));
2267     rmax = std::max(rmax, vmaxvq_f32(max_f32x4));
2268 #else
2269     float32x2_t min_f32x2 =
2270         vmin_f32(vget_low_f32(min_f32x4), vget_high_f32(min_f32x4));
2271     float32x2_t max_f32x2 =
2272         vmax_f32(vget_low_f32(max_f32x4), vget_high_f32(max_f32x4));
2273     min_f32x2 = vpmin_f32(min_f32x2, min_f32x2);
2274     max_f32x2 = vpmax_f32(max_f32x2, max_f32x2);
2275     rmin = std::min(rmin, vget_lane_f32(min_f32x2, 0));
2276     rmax = std::max(rmax, vget_lane_f32(max_f32x2, 0));
2277 #endif  // __aarch64__
2278   }
2279   if (TFLITE_UNLIKELY(i < size)) {
2280     const auto minmax =
2281         std::minmax_element(values + postamble_start, values + size);
2282     rmin = std::min(rmin, *minmax.first);
2283     rmax = std::max(rmax, *minmax.second);
2284   }
2285   *min = rmin;
2286   *max = rmax;
2287 }
2288 
NeonSymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float * min,float * max,float * scaling_factor)2289 void NeonSymmetricQuantizeFloats(const float* values, const int size,
2290                                  int8_t* quantized_values, float* min,
2291                                  float* max, float* scaling_factor) {
2292   // TODO(raziel): vectorize min/max calculation.
2293   auto minmax = std::minmax_element(values, values + size);
2294   *min = *minmax.first;
2295   *max = *minmax.second;
2296   NeonSymmetricQuantizeFloats(values, size, quantized_values, *min, *max,
2297                               scaling_factor);
2298 }
2299 
NeonSymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float min,float max,float * scaling_factor)2300 void NeonSymmetricQuantizeFloats(const float* values, const int size,
2301                                  int8_t* quantized_values, float min, float max,
2302                                  float* scaling_factor) {
2303   constexpr int kScale = 127;
2304   const float range = std::max(std::abs(min), std::abs(max));
2305   if (range == 0) {
2306     memset(quantized_values, 0, size * sizeof(int8_t));
2307     *scaling_factor = 1;
2308     return;
2309   }
2310   *scaling_factor = range / kScale;
2311   const float scaling_factor_inv = kScale / range;
2312 
2313   const int postamble_start =
2314       RoundDownVectors<(2 * kFloatValuesPerNeonVector)>(size);
2315 
2316   // Vectorized constants.
2317   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
2318   const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
2319   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
2320 
2321   int i = 0;
2322   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {
2323     // Implements the vectorized version of the following:
2324     // const int32 quantized_value = static_cast<int32>(
2325     //    std::round(*scaling_factor * values[i]));
2326     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2327     float32x4_t value1_f32x4 =
2328         vld1q_f32(&values[i + kFloatValuesPerNeonVector]);
2329     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
2330     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
2331 
2332     const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4);
2333     const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4);
2334 
2335     // Implements the vectorized version of the following block:
2336     //  quantized_values[i] = std::min(kScale, std::max(-kScale,
2337     //  quantized_value));
2338     int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
2339     int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
2340     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
2341     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
2342 
2343     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
2344     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
2345 
2346     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
2347     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
2348     vst1_s8(&quantized_values[i], min_s8x8);
2349   }
2350 
2351   for (; TFLITE_UNLIKELY(i < size); ++i) {
2352     const int32 quantized_value =
2353         static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
2354     quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
2355   }
2356 }
2357 
NeonAsymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float * scaling_factor,int32_t * offset)2358 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
2359                                   int8_t* quantized_values,
2360                                   float* scaling_factor, int32_t* offset) {
2361   float rmin, rmax;
2362   NeonMinMax(values, size, &rmin, &rmax);
2363 
2364   const int32_t kMinScale = -128;
2365   const int32_t kMaxScale = 127;
2366   const double qmin_double = kMinScale;
2367   const double qmax_double = kMaxScale;
2368   if (rmin == rmax) {
2369     memset(quantized_values, 0, size * sizeof(int8_t));
2370     *scaling_factor = 1;
2371     *offset = 0;
2372     return;
2373   } else {
2374     const double scale = (rmax - rmin) / (qmax_double - qmin_double);
2375     const double zero_point_from_min = qmin_double - rmin / scale;
2376     const double zero_point_from_max = qmax_double - rmax / scale;
2377     const double zero_point_from_min_error =
2378         std::abs(qmin_double) + std::abs(rmin / scale);
2379     const double zero_point_from_max_error =
2380         std::abs(qmax_double) + std::abs(rmax / scale);
2381     const double zero_point_double =
2382         zero_point_from_min_error < zero_point_from_max_error
2383             ? zero_point_from_min
2384             : zero_point_from_max;
2385     int8 nudged_zero_point = 0;
2386     if (zero_point_double <= qmin_double) {
2387       nudged_zero_point = kMinScale;
2388     } else if (zero_point_double >= qmax_double) {
2389       nudged_zero_point = kMaxScale;
2390     } else {
2391       nudged_zero_point = static_cast<int8>(round(zero_point_double));
2392     }
2393     *scaling_factor = scale;
2394     *offset = nudged_zero_point;
2395   }
2396 
2397   const int postamble_start =
2398       RoundDownVectors<(2 * kFloatValuesPerNeonVector)>(size);
2399   const float scaling_factor_inv =
2400       *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
2401   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
2402   const int32x4_t scale_i32x4 = vmovq_n_s32(kMaxScale);
2403   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(kMinScale);
2404   const int32x4_t offset_i32x4 = vmovq_n_s32(*offset);
2405 
2406   int i = 0;
2407   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {
2408     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2409     float32x4_t value1_f32x4 =
2410         vld1q_f32(&values[i + kFloatValuesPerNeonVector]);
2411     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
2412     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
2413 
2414     const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4);
2415     const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4);
2416 
2417     // Add offset
2418     int32x4_t q0_i32x4 = vaddq_s32(f2i0_i32x4, offset_i32x4);
2419     int32x4_t q1_i32x4 = vaddq_s32(f2i1_i32x4, offset_i32x4);
2420 
2421     int32x4_t max0_i32x4 = vmaxq_s32(q0_i32x4, neg_scale_i32x4);
2422     int32x4_t max1_i32x4 = vmaxq_s32(q1_i32x4, neg_scale_i32x4);
2423     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
2424     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
2425 
2426     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
2427     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
2428 
2429     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
2430     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
2431     vst1_s8(&quantized_values[i], min_s8x8);
2432   }
2433 
2434   for (; TFLITE_UNLIKELY(i < size); ++i) {
2435     const int32 quantized_value = static_cast<int32>(
2436         *offset + TfLiteRound(scaling_factor_inv * values[i]));
2437     quantized_values[i] =
2438         std::min(kMaxScale, std::max(kMinScale, quantized_value));
2439   }
2440 }
2441 
NeonVectorVectorDotProduct(const float * vector1,const float * vector2,int v_size)2442 float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
2443                                  int v_size) {
2444   // If v_size is not divisible by the vector size, then we need to process the
2445   // final few elements sequentially. postamble_start shows the start index
2446   // where this should happen.
2447   const int postamble_start =
2448       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2449   float32x4_t acc_32x4 = vmovq_n_f32(0.0);
2450   int v = 0;
2451   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2452     // Load 4 float values from vector1 and vector2 and accumulator.
2453     float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
2454     float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
2455     // Vector multiply-accumulate 4 float
2456     acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
2457   }
2458   float result = AccumulateNeonLane(acc_32x4);
2459   // Postamble loop.
2460   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2461     result += vector1[v] * vector2[v];
2462   }
2463   return result;
2464 }
2465 
NeonReductionSumVector(const float * input_vector,float * output_vector,int output_size,int reduction_size)2466 void NeonReductionSumVector(const float* input_vector, float* output_vector,
2467                             int output_size, int reduction_size) {
2468   for (int o = 0; o < output_size; o++) {
2469     // If v_size is not divisible by the vector size, then we need to process
2470     // the final few elements sequentially. postamble_start shows the start
2471     // index where this should happen.
2472     const int postamble_start =
2473         RoundDownVectors<kFloatValuesPerNeonVector>(reduction_size);
2474     float32x4_t sum_f32x4 = vmovq_n_f32(0.0);
2475     int r = 0;
2476     for (; r < postamble_start; r += kFloatValuesPerNeonVector) {
2477       float32x4_t v1_f32x4 = vld1q_f32(input_vector + r);
2478       sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4);
2479     }
2480     float sum = AccumulateNeonLane(sum_f32x4);
2481     // Postamble loop.
2482     for (; TFLITE_UNLIKELY(r < reduction_size); r++) {
2483       sum += input_vector[r];
2484     }
2485     output_vector[o] = sum;
2486     input_vector += reduction_size;
2487   }
2488 }
2489 
NeonReductionSumVector(const int8_t * input_vector,int32_t * output_vector,const int output_size,const int reduction_size)2490 void NeonReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
2491                             const int output_size, const int reduction_size) {
2492   const int postamble_half_start =
2493       RoundDownVectors<kInt8ValuesPerNeonVector>(reduction_size);
2494   const int postamble_start =
2495       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(reduction_size);
2496   for (int o = 0; o < output_size; ++o) {
2497     int32x4_t sum_32x4 = vmovq_n_s32(0);
2498     int r = 0;
2499     for (; r < postamble_half_start; r += kInt8ValuesPerNeonVector) {
2500       const int8x16_t s2_8x16 = vld1q_s8(input_vector + r);
2501       sum_32x4 = vpadalq_s16(sum_32x4, vpaddlq_s8(s2_8x16));
2502     }
2503     if (TFLITE_UNLIKELY(r < postamble_start)) {
2504       const int8x8_t s2_8x8 = vld1_s8(input_vector + r);
2505       sum_32x4 = vpadalq_s16(sum_32x4, vmovl_s8(s2_8x8));
2506       r += (kInt8ValuesPerNeonVector >> 1);
2507     }
2508     int32_t sum = AccumulateNeonLane(sum_32x4);
2509     for (; TFLITE_UNLIKELY(r < reduction_size); ++r) {
2510       sum += input_vector[r];
2511     }
2512     output_vector[o] = sum;
2513     input_vector += reduction_size;
2514   }
2515 }
2516 
NeonVectorBatchVectorCwiseProductAccumulate(const int16_t * vector,int v_size,const int16_t * batch_vector,int n_batch,int32_t multiplier,int shift,int16_t * result)2517 void NeonVectorBatchVectorCwiseProductAccumulate(
2518     const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
2519     int32_t multiplier, int shift, int16_t* result) {
2520   int32x4_t min_value_vector = vdupq_n_s32(-32768);
2521   int32x4_t max_value_vector = vdupq_n_s32(32767);
2522 
2523   for (int b = 0; b < n_batch; b++) {
2524     int v = 0;
2525     for (; v <= v_size - 16; v += 16) {
2526       int32x4x4_t prod;
2527       prod.val[0] = vmull_s16(vld1_s16(vector + v), vld1_s16(batch_vector));
2528       prod.val[1] =
2529           vmull_s16(vld1_s16(vector + v + 4), vld1_s16(batch_vector + 4));
2530       prod.val[2] =
2531           vmull_s16(vld1_s16(vector + v + 8), vld1_s16(batch_vector + 8));
2532       prod.val[3] =
2533           vmull_s16(vld1_s16(vector + v + 12), vld1_s16(batch_vector + 12));
2534       batch_vector += 16;
2535 
2536       prod = MultiplyByQuantizedMultiplier4Rows(prod, multiplier, shift);
2537 
2538       int16x4x4_t results;
2539       results.val[0] = vld1_s16(result);
2540       results.val[1] = vld1_s16(result + 4);
2541       results.val[2] = vld1_s16(result + 8);
2542       results.val[3] = vld1_s16(result + 12);
2543 
2544       prod.val[0] = vaddq_s32(prod.val[0], vmovl_s16(results.val[0]));
2545       prod.val[1] = vaddq_s32(prod.val[1], vmovl_s16(results.val[1]));
2546       prod.val[2] = vaddq_s32(prod.val[2], vmovl_s16(results.val[2]));
2547       prod.val[3] = vaddq_s32(prod.val[3], vmovl_s16(results.val[3]));
2548 
2549       prod.val[0] = vmaxq_s32(prod.val[0], min_value_vector);
2550       prod.val[1] = vmaxq_s32(prod.val[1], min_value_vector);
2551       prod.val[2] = vmaxq_s32(prod.val[2], min_value_vector);
2552       prod.val[3] = vmaxq_s32(prod.val[3], min_value_vector);
2553 
2554       prod.val[0] = vminq_s32(prod.val[0], max_value_vector);
2555       prod.val[1] = vminq_s32(prod.val[1], max_value_vector);
2556       prod.val[2] = vminq_s32(prod.val[2], max_value_vector);
2557       prod.val[3] = vminq_s32(prod.val[3], max_value_vector);
2558 
2559       vst1_s16(result, vmovn_s32(prod.val[0]));
2560       vst1_s16(result + 4, vmovn_s32(prod.val[1]));
2561       vst1_s16(result + 8, vmovn_s32(prod.val[2]));
2562       vst1_s16(result + 12, vmovn_s32(prod.val[3]));
2563 
2564       result += 16;
2565     }
2566 
2567     for (; TFLITE_UNLIKELY(v < v_size); v++) {
2568       int32_t prod = vector[v] * *batch_vector++;
2569       prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
2570       int32_t output = prod + *result;
2571       output = std::max(std::min(32767, output), -32768);
2572       *result++ = output;
2573     }
2574   }
2575 }
2576 
NeonMeanStddevNormalization(const float * __restrict__ input_vector,float * __restrict__ output_vector,int v_size,int n_batch)2577 void NeonMeanStddevNormalization(const float* __restrict__ input_vector,
2578                                  float* __restrict__ output_vector, int v_size,
2579                                  int n_batch) {
2580   constexpr int kBlockSize = kFloatValuesPerNeonVector * 4;
2581 
2582   for (int batch = 0; batch < n_batch; ++batch) {
2583     // Calculate sum
2584     float32x4_t sum_f32x4_0 = vdupq_n_f32(0.0f);
2585     float32x4_t sum_f32x4_1 = vdupq_n_f32(0.0f);
2586     float32x4_t sum_f32x4_2 = vdupq_n_f32(0.0f);
2587     float32x4_t sum_f32x4_3 = vdupq_n_f32(0.0f);
2588     int i = 0;
2589     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2590       const float32x4_t input_f32x4_0 =
2591           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2592       const float32x4_t input_f32x4_1 =
2593           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2594       const float32x4_t input_f32x4_2 =
2595           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2596       const float32x4_t input_f32x4_3 =
2597           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2598       sum_f32x4_0 = vaddq_f32(sum_f32x4_0, input_f32x4_0);
2599       sum_f32x4_1 = vaddq_f32(sum_f32x4_1, input_f32x4_1);
2600       sum_f32x4_2 = vaddq_f32(sum_f32x4_2, input_f32x4_2);
2601       sum_f32x4_3 = vaddq_f32(sum_f32x4_3, input_f32x4_3);
2602     }
2603     sum_f32x4_0 = vaddq_f32(sum_f32x4_0, sum_f32x4_2);
2604     sum_f32x4_1 = vaddq_f32(sum_f32x4_1, sum_f32x4_3);
2605     sum_f32x4_0 = vaddq_f32(sum_f32x4_0, sum_f32x4_1);
2606     float sum = AccumulateNeonLane(sum_f32x4_0);
2607     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2608       sum += input_vector[i];
2609     }
2610     // Calculate mean
2611     const float mean = sum / v_size;
2612     const float32x4_t mean_f32x4 = vdupq_n_f32(mean);
2613     // Calculate sum of squared differences
2614     float32x4_t sum_diff_sq_f32x4_0 = vdupq_n_f32(0.0f);
2615     float32x4_t sum_diff_sq_f32x4_1 = vdupq_n_f32(0.0f);
2616     float32x4_t sum_diff_sq_f32x4_2 = vdupq_n_f32(0.0f);
2617     float32x4_t sum_diff_sq_f32x4_3 = vdupq_n_f32(0.0f);
2618     i = 0;
2619     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2620       const float32x4_t input_f32x4_0 =
2621           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2622       const float32x4_t input_f32x4_1 =
2623           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2624       const float32x4_t input_f32x4_2 =
2625           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2626       const float32x4_t input_f32x4_3 =
2627           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2628       const float32x4_t diff_f32x4_0 = vsubq_f32(input_f32x4_0, mean_f32x4);
2629       const float32x4_t diff_f32x4_1 = vsubq_f32(input_f32x4_1, mean_f32x4);
2630       const float32x4_t diff_f32x4_2 = vsubq_f32(input_f32x4_2, mean_f32x4);
2631       const float32x4_t diff_f32x4_3 = vsubq_f32(input_f32x4_3, mean_f32x4);
2632       sum_diff_sq_f32x4_0 =
2633           vmlaq_f32(sum_diff_sq_f32x4_0, diff_f32x4_0, diff_f32x4_0);
2634       sum_diff_sq_f32x4_1 =
2635           vmlaq_f32(sum_diff_sq_f32x4_1, diff_f32x4_1, diff_f32x4_1);
2636       sum_diff_sq_f32x4_2 =
2637           vmlaq_f32(sum_diff_sq_f32x4_2, diff_f32x4_2, diff_f32x4_2);
2638       sum_diff_sq_f32x4_3 =
2639           vmlaq_f32(sum_diff_sq_f32x4_3, diff_f32x4_3, diff_f32x4_3);
2640     }
2641     sum_diff_sq_f32x4_0 = vaddq_f32(sum_diff_sq_f32x4_0, sum_diff_sq_f32x4_2);
2642     sum_diff_sq_f32x4_1 = vaddq_f32(sum_diff_sq_f32x4_1, sum_diff_sq_f32x4_3);
2643     sum_diff_sq_f32x4_0 = vaddq_f32(sum_diff_sq_f32x4_0, sum_diff_sq_f32x4_1);
2644     float sum_diff_sq = AccumulateNeonLane(sum_diff_sq_f32x4_0);
2645     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2646       const float diff = input_vector[i] - mean;
2647       sum_diff_sq += diff * diff;
2648     }
2649     // Calculate 1/stddev
2650     const float variance = sum_diff_sq / v_size;
2651     constexpr float kNormalizationConstant = 1e-8f;
2652     const float stddev_inv =
2653         1.0f / std::sqrt(variance + kNormalizationConstant);
2654     // Do the normalization
2655     i = 0;
2656     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2657       const float32x4_t input_f32x4_0 =
2658           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2659       const float32x4_t input_f32x4_1 =
2660           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2661       const float32x4_t input_f32x4_2 =
2662           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2663       const float32x4_t input_f32x4_3 =
2664           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2665       const float32x4_t tmp_0 = vsubq_f32(input_f32x4_0, mean_f32x4);
2666       const float32x4_t tmp_1 = vsubq_f32(input_f32x4_1, mean_f32x4);
2667       const float32x4_t tmp_2 = vsubq_f32(input_f32x4_2, mean_f32x4);
2668       const float32x4_t tmp_3 = vsubq_f32(input_f32x4_3, mean_f32x4);
2669       const float32x4_t output_f32x4_0 = vmulq_n_f32(tmp_0, stddev_inv);
2670       const float32x4_t output_f32x4_1 = vmulq_n_f32(tmp_1, stddev_inv);
2671       const float32x4_t output_f32x4_2 = vmulq_n_f32(tmp_2, stddev_inv);
2672       const float32x4_t output_f32x4_3 = vmulq_n_f32(tmp_3, stddev_inv);
2673       vst1q_f32(output_vector + i + 0 * kFloatValuesPerNeonVector,
2674                 output_f32x4_0);
2675       vst1q_f32(output_vector + i + 1 * kFloatValuesPerNeonVector,
2676                 output_f32x4_1);
2677       vst1q_f32(output_vector + i + 2 * kFloatValuesPerNeonVector,
2678                 output_f32x4_2);
2679       vst1q_f32(output_vector + i + 3 * kFloatValuesPerNeonVector,
2680                 output_f32x4_3);
2681     }
2682     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2683       output_vector[i] = (input_vector[i] - mean) * stddev_inv;
2684     }
2685     // Advance to next batch
2686     input_vector += v_size;
2687     output_vector += v_size;
2688   }
2689 }
2690 
2691 }  // namespace tensor_utils
2692 }  // namespace tflite
2693 
2694 #endif  // USE_NEON
2695