1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17 
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29 
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34 
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38 
39 #include "Eigen/Core"
40 #include "fixedpoint/fixedpoint.h"
41 #include "ruy/profiler/instrumentation.h"  // from @ruy
42 #include "tensorflow/lite/c/common.h"
43 #include "tensorflow/lite/kernels/cpu_backend_context.h"
44 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
46 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
47 #include "tensorflow/lite/kernels/internal/cppmath.h"
48 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
49 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
50 #include "tensorflow/lite/kernels/internal/quantization_util.h"
51 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
52 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
53 #include "tensorflow/lite/kernels/internal/tensor.h"
54 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
55 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
56 #include "tensorflow/lite/kernels/internal/types.h"
57 #include "unsupported/Eigen/CXX11/Tensor"
58 
59 #if __aarch64__ && __clang__
60 #define TFLITE_SOFTMAX_USE_UINT16_LUT
61 #endif
62 
63 namespace tflite {
64 namespace optimized_ops {
65 
66 // Unoptimized reference ops:
67 using reference_ops::ArgMax;
68 using reference_ops::ArgMinMax;
69 using reference_ops::Broadcast4DSlowGreater;
70 using reference_ops::Broadcast4DSlowGreaterEqual;
71 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
72 using reference_ops::Broadcast4DSlowGreaterWithScaling;
73 using reference_ops::Broadcast4DSlowLess;
74 using reference_ops::Broadcast4DSlowLessEqual;
75 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
76 using reference_ops::Broadcast4DSlowLessWithScaling;
77 using reference_ops::BroadcastAdd4DSlow;
78 using reference_ops::BroadcastMul4DSlow;
79 using reference_ops::BroadcastSub16POTSlow;
80 using reference_ops::BroadcastSubSlow;
81 using reference_ops::Concatenation;
82 using reference_ops::ConcatenationWithScaling;
83 using reference_ops::DepthConcatenation;
84 using reference_ops::Div;
85 using reference_ops::Elu;
86 using reference_ops::FakeQuant;
87 using reference_ops::Fill;
88 using reference_ops::Gather;
89 using reference_ops::Greater;
90 using reference_ops::GreaterEqual;
91 using reference_ops::GreaterEqualWithScaling;
92 using reference_ops::GreaterWithScaling;
93 using reference_ops::LeakyRelu;
94 using reference_ops::Less;
95 using reference_ops::LessEqual;
96 using reference_ops::LessEqualWithScaling;
97 using reference_ops::LessWithScaling;
98 using reference_ops::Mean;
99 using reference_ops::ProcessBroadcastShapes;
100 using reference_ops::RankOneSelect;
101 using reference_ops::Relu1;
102 using reference_ops::Relu6;
103 using reference_ops::ReluX;
104 using reference_ops::Round;
105 using reference_ops::Select;
106 using reference_ops::SpaceToBatchND;
107 using reference_ops::Split;
108 using reference_ops::Sub16;
109 
// TODO(b/80247582) Remove this constant.
// Multiplier used to turn an old-style shift (expressed as a right shift)
// into a new-style shift (expressed as a left shift). It is deliberately a
// named constant rather than a literal -1 so that the remaining call sites can
// be found while the shift conventions are being revised.
constexpr int kReverseShift = -1;
116 
117 // Make a local VectorMap typedef allowing to map a float array
118 // as a Eigen vector expression. The std::conditional here is to
119 // construct the suitable Eigen type for the constness of the
120 // data. Indeed, for const data, we need to produce
121 //    Eigen::Map<const Eigen::Matrix<float, ...>>
122 // and not the more straightforward
123 //    Eigen::Map<Eigen::Matrix<const float, ...>>
124 template <typename Scalar>
125 using VectorMap = typename std::conditional<
126     std::is_const<Scalar>::value,
127     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
128                                    Eigen::Dynamic, 1>>,
129     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
130 
131 template <typename Scalar>
MapAsVector(Scalar * data,const RuntimeShape & shape)132 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
133   const int size = shape.FlatSize();
134   return VectorMap<Scalar>(data, size, 1);
135 }
136 
137 // Make a local VectorMap typedef allowing to map a float array
138 // as a Eigen matrix expression. The same explanation as for VectorMap
139 // above also applies here.
140 template <typename Scalar>
141 using MatrixMap = typename std::conditional<
142     std::is_const<Scalar>::value,
143     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
144                                    Eigen::Dynamic, Eigen::Dynamic>>,
145     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
146 
147 template <typename Scalar>
MapAsMatrixWithLastDimAsRows(Scalar * data,const RuntimeShape & shape)148 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
149                                                const RuntimeShape& shape) {
150   const int dims_count = shape.DimensionsCount();
151   const int rows = shape.Dims(dims_count - 1);
152   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
153   return MatrixMap<Scalar>(data, rows, cols);
154 }
155 
156 template <typename Scalar>
MapAsMatrixWithFirstDimAsCols(Scalar * data,const RuntimeShape & shape)157 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
158                                                 const RuntimeShape& shape) {
159   const int cols = shape.Dims(0);
160   const int rows = FlatSizeSkipDim(shape, 0);
161   return MatrixMap<Scalar>(data, rows, cols);
162 }
163 
164 template <typename Scalar>
165 using ArrayMap = typename std::conditional<
166     std::is_const<Scalar>::value,
167     Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
168                                   Eigen::Dynamic, Eigen::Dynamic>>,
169     Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
170 
171 template <typename Scalar>
MapAsArrayWithLastDimAsRows(Scalar * data,const RuntimeShape & shape)172 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
173                                              const RuntimeShape& shape) {
174   const int dims_count = shape.DimensionsCount();
175   const int rows = shape.Dims(dims_count - 1);
176   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
177   return ArrayMap<Scalar>(data, rows, cols);
178 }
179 
180 // Copied from tensorflow/core/framework/tensor_types.h
181 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
182 struct TTypes {
183   // Rank-1 tensor (vector) of scalar type T.
184   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
185                            Eigen::Aligned>
186       Flat;
187   typedef Eigen::TensorMap<
188       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
189       UnalignedConstMatrix;
190 };
191 
192 // TODO(b/62193649): this function is only needed as long
193 // as we have the --variable_batch hack.
194 template <typename Scalar>
MapAsMatrixWithGivenNumberOfRows(Scalar * data,const RuntimeShape & shape,int rows)195 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
196                                                    const RuntimeShape& shape,
197                                                    int rows) {
198   const int flatsize = shape.FlatSize();
199   TFLITE_DCHECK_EQ(flatsize % rows, 0);
200   const int cols = flatsize / rows;
201   return MatrixMap<Scalar>(data, rows, cols);
202 }
203 
204 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
BinaryBroadcastFiveFold(const ArithmeticParams & unswitched_params,const RuntimeShape & unswitched_input1_shape,const T * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const T * unswitched_input2_data,const RuntimeShape & output_shape,T * output_data,ElementwiseF elementwise_f,ScalarBroadcastF scalar_broadcast_f)205 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
206                                     const RuntimeShape& unswitched_input1_shape,
207                                     const T* unswitched_input1_data,
208                                     const RuntimeShape& unswitched_input2_shape,
209                                     const T* unswitched_input2_data,
210                                     const RuntimeShape& output_shape,
211                                     T* output_data, ElementwiseF elementwise_f,
212                                     ScalarBroadcastF scalar_broadcast_f) {
213   ArithmeticParams switched_params = unswitched_params;
214   switched_params.input1_offset = unswitched_params.input2_offset;
215   switched_params.input1_multiplier = unswitched_params.input2_multiplier;
216   switched_params.input1_shift = unswitched_params.input2_shift;
217   switched_params.input2_offset = unswitched_params.input1_offset;
218   switched_params.input2_multiplier = unswitched_params.input1_multiplier;
219   switched_params.input2_shift = unswitched_params.input1_shift;
220 
221   const bool use_unswitched =
222       unswitched_params.broadcast_category ==
223       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
224 
225   const ArithmeticParams& params =
226       use_unswitched ? unswitched_params : switched_params;
227   const T* input1_data =
228       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
229   const T* input2_data =
230       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
231 
232   // Fivefold nested loops. The second input resets its position for each
233   // iteration of the second loop. The first input resets its position at the
234   // beginning of the fourth loop. The innermost loop is an elementwise add of
235   // sections of the arrays.
236   T* output_data_ptr = output_data;
237   const T* input1_data_ptr = input1_data;
238   const T* input2_data_reset = input2_data;
239   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
240   // between input shapes. y3 for input 1 is always broadcast, and so the
241   // dimension there is 1, whereas optionally y1 might be broadcast for
242   // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
243   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
244   int y0 = params.broadcast_shape[0];
245   int y1 = params.broadcast_shape[1];
246   int y2 = params.broadcast_shape[2];
247   int y3 = params.broadcast_shape[3];
248   int y4 = params.broadcast_shape[4];
249   if (y4 > 1) {
250     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
251     // dimension.
252     for (int i0 = 0; i0 < y0; ++i0) {
253       const T* input2_data_ptr = nullptr;
254       for (int i1 = 0; i1 < y1; ++i1) {
255         input2_data_ptr = input2_data_reset;
256         for (int i2 = 0; i2 < y2; ++i2) {
257           for (int i3 = 0; i3 < y3; ++i3) {
258             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
259                           output_data_ptr);
260             input2_data_ptr += y4;
261             output_data_ptr += y4;
262           }
263           // We have broadcast y4 of input1 data y3 times, and now move on.
264           input1_data_ptr += y4;
265         }
266       }
267       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
268       input2_data_reset = input2_data_ptr;
269     }
270   } else {
271     // Special case of y4 == 1, in which the innermost loop is a single
272     // element and can be combined with the next (y3) as an inner broadcast.
273     //
274     // Note that this handles the case of pure scalar broadcast when
275     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
276     // broadcast with batch (as y2 > 1).
277     //
278     // NOTE The process is the same as the above general case except
279     // simplified for y4 == 1 and the loop over y3 is contained within the
280     // AddScalarBroadcast function.
281     for (int i0 = 0; i0 < y0; ++i0) {
282       const T* input2_data_ptr = nullptr;
283       for (int i1 = 0; i1 < y1; ++i1) {
284         input2_data_ptr = input2_data_reset;
285         for (int i2 = 0; i2 < y2; ++i2) {
286           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
287                              output_data_ptr);
288           input2_data_ptr += y3;
289           output_data_ptr += y3;
290           input1_data_ptr += 1;
291         }
292       }
293       input2_data_reset = input2_data_ptr;
294     }
295   }
296 }
297 
298 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
299 
300 // Looks up each element of <indices> in <table>, returns them in a vector.
aarch64_lookup_vector(const uint8x16x4_t table[4],uint8x16_t indices)301 inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
302                                         uint8x16_t indices) {
303   // Look up in 1st quarter of the table: top 2 bits of indices == 00
304   uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
305   // Look up in 2nd quarter of the table: top 2 bits of indices == 01
306   uint8x16_t output2 =
307       vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
308   // Look up in 3rd quarter of the table: top 2 bits of indices == 10
309   uint8x16_t output3 =
310       vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
311   // Look up in 4th quarter of the table: top 2 bits of indices == 11
312   uint8x16_t output4 =
313       vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
314 
315   // Combine result of the 4 lookups.
316   return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
317 }
318 
319 #endif
320 
AddBiasAndEvalActivationFunction(float output_activation_min,float output_activation_max,const RuntimeShape & bias_shape,const float * bias_data,const RuntimeShape & array_shape,float * array_data)321 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
322                                              float output_activation_max,
323                                              const RuntimeShape& bias_shape,
324                                              const float* bias_data,
325                                              const RuntimeShape& array_shape,
326                                              float* array_data) {
327   BiasAndClamp(output_activation_min, output_activation_max,
328                bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
329                array_data);
330 }
331 
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & weights_shape,const float * weights_data,const RuntimeShape & bias_shape,const float * optional_bias_data,const RuntimeShape & output_shape,float * output_data,CpuBackendContext * cpu_backend_context)332 inline void FullyConnected(
333     const FullyConnectedParams& params, const RuntimeShape& input_shape,
334     const float* input_data, const RuntimeShape& weights_shape,
335     const float* weights_data, const RuntimeShape& bias_shape,
336     const float* optional_bias_data, const RuntimeShape& output_shape,
337     float* output_data, CpuBackendContext* cpu_backend_context) {
338   ruy::profiler::ScopeLabel label("FullyConnected");
339   const int dims_count = weights_shape.DimensionsCount();
340   const int input_rows = weights_shape.Dims(dims_count - 1);
341   cpu_backend_gemm::MatrixParams<float> rhs_params;
342   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
343   rhs_params.rows = input_rows;
344   rhs_params.cols = input_shape.FlatSize() / input_rows;
345   rhs_params.cache_policy =
346       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
347   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
348   cpu_backend_gemm::MatrixParams<float> lhs_params;
349   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
350   lhs_params.cols = weights_shape.Dims(dims_count - 1);
351   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
352   lhs_params.cache_policy =
353       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
354   cpu_backend_gemm::MatrixParams<float> dst_params;
355   dst_params.order = cpu_backend_gemm::Order::kColMajor;
356   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
357   dst_params.cols =
358       FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
359   cpu_backend_gemm::GemmParams<float, float> gemm_params;
360   gemm_params.bias = optional_bias_data;
361   gemm_params.clamp_min = params.float_activation_min;
362   gemm_params.clamp_max = params.float_activation_max;
363   cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
364                          dst_params, output_data, gemm_params,
365                          cpu_backend_context);
366 }
367 
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8 * output_data,CpuBackendContext * cpu_backend_context)368 inline void FullyConnected(
369     const FullyConnectedParams& params, const RuntimeShape& input_shape,
370     const uint8* input_data, const RuntimeShape& filter_shape,
371     const uint8* filter_data, const RuntimeShape& bias_shape,
372     const int32* bias_data, const RuntimeShape& output_shape,
373     uint8* output_data, CpuBackendContext* cpu_backend_context) {
374   ruy::profiler::ScopeLabel label("FullyConnected/8bit");
375   const int32 input_offset = params.input_offset;
376   const int32 filter_offset = params.weights_offset;
377   const int32 output_offset = params.output_offset;
378   const int32 output_multiplier = params.output_multiplier;
379   const int output_shift = params.output_shift;
380   const int32 output_activation_min = params.quantized_activation_min;
381   const int32 output_activation_max = params.quantized_activation_max;
382   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
383   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
384   // TODO(b/62193649): This really should be:
385   //     const int batches = ArraySize(output_dims, 1);
386   // but the current --variable_batch hack consists in overwriting the 3rd
387   // dimension with the runtime batch size, as we don't keep track for each
388   // array of which dimension is the batch dimension in it.
389   const int output_dim_count = output_shape.DimensionsCount();
390   const int filter_dim_count = filter_shape.DimensionsCount();
391   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
392   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
393   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
394   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
395   const int output_rows = output_shape.Dims(output_dim_count - 1);
396   TFLITE_DCHECK_EQ(output_rows, filter_rows);
397   if (bias_data) {
398     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
399   }
400 
401   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
402   lhs_params.rows = filter_rows;
403   lhs_params.cols = filter_cols;
404   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
405   lhs_params.zero_point = -filter_offset;
406   lhs_params.cache_policy =
407       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
408   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
409   rhs_params.rows = filter_cols;
410   rhs_params.cols = batches;
411   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
412   rhs_params.zero_point = -input_offset;
413   rhs_params.cache_policy =
414       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
415   cpu_backend_gemm::MatrixParams<uint8> dst_params;
416   dst_params.rows = filter_rows;
417   dst_params.cols = batches;
418   dst_params.order = cpu_backend_gemm::Order::kColMajor;
419   dst_params.zero_point = output_offset;
420   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
421   gemm_params.bias = bias_data;
422   gemm_params.clamp_min = output_activation_min;
423   gemm_params.clamp_max = output_activation_max;
424   gemm_params.multiplier_fixedpoint = output_multiplier;
425   gemm_params.multiplier_exponent = output_shift;
426   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
427                          dst_params, output_data, gemm_params,
428                          cpu_backend_context);
429 }
430 
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data_int32,const RuntimeShape & output_shape,int16 * output_data,CpuBackendContext * cpu_backend_context)431 inline void FullyConnected(
432     const FullyConnectedParams& params, const RuntimeShape& input_shape,
433     const uint8* input_data, const RuntimeShape& filter_shape,
434     const uint8* filter_data, const RuntimeShape& bias_shape,
435     const int32* bias_data_int32, const RuntimeShape& output_shape,
436     int16* output_data, CpuBackendContext* cpu_backend_context) {
437   ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
438   const int32 input_offset = params.input_offset;
439   const int32 filter_offset = params.weights_offset;
440   const int32 output_offset = params.output_offset;
441   const int32 output_multiplier = params.output_multiplier;
442   const int output_shift = params.output_shift;
443   const int32 output_activation_min = params.quantized_activation_min;
444   const int32 output_activation_max = params.quantized_activation_max;
445   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
446   TFLITE_DCHECK_EQ(output_offset, 0);
447   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
448   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
449 
450   // TODO(b/62193649): This really should be:
451   //     const int batches = ArraySize(output_dims, 1);
452   // but the current --variable_batch hack consists in overwriting the 3rd
453   // dimension with the runtime batch size, as we don't keep track for each
454   // array of which dimension is the batch dimension in it.
455   const int output_dim_count = output_shape.DimensionsCount();
456   const int filter_dim_count = filter_shape.DimensionsCount();
457   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
458   const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
459                                        output_shape, output_dim_count - 1);
460   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
461 
462   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
463   lhs_params.rows = output_depth;
464   lhs_params.cols = accum_depth;
465   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
466   lhs_params.zero_point = -filter_offset;
467   lhs_params.cache_policy =
468       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
469   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
470   rhs_params.rows = accum_depth;
471   rhs_params.cols = batches;
472   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
473   rhs_params.zero_point = -input_offset;
474   rhs_params.cache_policy =
475       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
476   cpu_backend_gemm::MatrixParams<int16> dst_params;
477   dst_params.rows = output_depth;
478   dst_params.cols = batches;
479   dst_params.order = cpu_backend_gemm::Order::kColMajor;
480   dst_params.zero_point = 0;
481   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
482   gemm_params.bias = bias_data_int32;
483   gemm_params.clamp_min = output_activation_min;
484   gemm_params.clamp_max = output_activation_max;
485   gemm_params.multiplier_fixedpoint = output_multiplier;
486   gemm_params.multiplier_exponent = output_shift;
487   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
488                          dst_params, output_data, gemm_params,
489                          cpu_backend_context);
490 }
491 
492 // Internal function doing the actual arithmetic work for
493 // ShuffledFullyConnected.
494 // May be called either directly by it (single-threaded case) or may be used
495 // as the 'task' for worker threads to run (multi-threaded case, see
496 // ShuffledFullyConnectedWorkerTask below).
ShuffledFullyConnectedWorkerImpl(const uint8 * shuffled_input_workspace_data,const int8 * shuffled_weights_data,int batches,int output_depth,int output_stride,int accum_depth,const int32 * bias_data,int32 output_multiplier,int output_shift,int16 * output_data)497 inline void ShuffledFullyConnectedWorkerImpl(
498     const uint8* shuffled_input_workspace_data,
499     const int8* shuffled_weights_data, int batches, int output_depth,
500     int output_stride, int accum_depth, const int32* bias_data,
501     int32 output_multiplier, int output_shift, int16* output_data) {
502 #if defined USE_NEON
503   const int8* shuffled_weights_ptr = shuffled_weights_data;
504   if (batches == 1) {
505     const int right_shift = output_shift > 0 ? 0 : -output_shift;
506     const int left_shift = output_shift > 0 ? output_shift : 0;
507     for (int c = 0; c < output_depth; c += 4) {
508       // Accumulation loop.
509       int32x4_t row_accum0 = vdupq_n_s32(0);
510       int32x4_t row_accum1 = vdupq_n_s32(0);
511       int32x4_t row_accum2 = vdupq_n_s32(0);
512       int32x4_t row_accum3 = vdupq_n_s32(0);
513       for (int d = 0; d < accum_depth; d += 16) {
514         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
515         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
516         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
517         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
518         shuffled_weights_ptr += 64;
519         int8x16_t input =
520             vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
521         int16x8_t local_accum0 =
522             vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
523         int16x8_t local_accum1 =
524             vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
525         int16x8_t local_accum2 =
526             vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
527         int16x8_t local_accum3 =
528             vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
529         local_accum0 =
530             vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
531         local_accum1 =
532             vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
533         local_accum2 =
534             vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
535         local_accum3 =
536             vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
537         row_accum0 = vpadalq_s16(row_accum0, local_accum0);
538         row_accum1 = vpadalq_s16(row_accum1, local_accum1);
539         row_accum2 = vpadalq_s16(row_accum2, local_accum2);
540         row_accum3 = vpadalq_s16(row_accum3, local_accum3);
541       }
542       // Horizontally reduce accumulators
543       int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
544           pairwise_reduced_acc_2, pairwise_reduced_acc_3;
545       pairwise_reduced_acc_0 =
546           vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
547       pairwise_reduced_acc_1 =
548           vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
549       pairwise_reduced_acc_2 =
550           vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
551       pairwise_reduced_acc_3 =
552           vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
553       const int32x2_t reduced_lo =
554           vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
555       const int32x2_t reduced_hi =
556           vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
557       int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
558       // Add bias values.
559       int32x4_t bias_vec = vld1q_s32(bias_data + c);
560       reduced = vaddq_s32(reduced, bias_vec);
561       reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
562       // Multiply by the fixed-point multiplier.
563       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
564       // Rounding-shift-right.
565       using gemmlowp::RoundingDivideByPOT;
566       reduced = RoundingDivideByPOT(reduced, right_shift);
567       // Narrow values down to 16 bit signed.
568       const int16x4_t res16 = vqmovn_s32(reduced);
569       vst1_s16(output_data + c, res16);
570     }
571   } else if (batches == 4) {
572     const int right_shift = output_shift > 0 ? 0 : -output_shift;
573     const int left_shift = output_shift > 0 ? output_shift : 0;
574     for (int c = 0; c < output_depth; c += 4) {
575       const int8* shuffled_input_ptr =
576           reinterpret_cast<const int8*>(shuffled_input_workspace_data);
577       // Accumulation loop.
578       int32x4_t row_accum00 = vdupq_n_s32(0);
579       int32x4_t row_accum10 = vdupq_n_s32(0);
580       int32x4_t row_accum20 = vdupq_n_s32(0);
581       int32x4_t row_accum30 = vdupq_n_s32(0);
582       int32x4_t row_accum01 = vdupq_n_s32(0);
583       int32x4_t row_accum11 = vdupq_n_s32(0);
584       int32x4_t row_accum21 = vdupq_n_s32(0);
585       int32x4_t row_accum31 = vdupq_n_s32(0);
586       int32x4_t row_accum02 = vdupq_n_s32(0);
587       int32x4_t row_accum12 = vdupq_n_s32(0);
588       int32x4_t row_accum22 = vdupq_n_s32(0);
589       int32x4_t row_accum32 = vdupq_n_s32(0);
590       int32x4_t row_accum03 = vdupq_n_s32(0);
591       int32x4_t row_accum13 = vdupq_n_s32(0);
592       int32x4_t row_accum23 = vdupq_n_s32(0);
593       int32x4_t row_accum33 = vdupq_n_s32(0);
594       for (int d = 0; d < accum_depth; d += 16) {
595         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
596         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
597         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
598         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
599         shuffled_weights_ptr += 64;
600         int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
601         int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
602         int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
603         int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
604         shuffled_input_ptr += 64;
605         int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
606 #define TFLITE_SHUFFLED_FC_ACCUM(B)                                           \
607   local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B));      \
608   local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B));      \
609   local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B));      \
610   local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B));      \
611   local_accum0 =                                                              \
612       vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
613   local_accum1 =                                                              \
614       vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
615   local_accum2 =                                                              \
616       vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
617   local_accum3 =                                                              \
618       vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
619   row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0);                   \
620   row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1);                   \
621   row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2);                   \
622   row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
623 
624         TFLITE_SHUFFLED_FC_ACCUM(0)
625         TFLITE_SHUFFLED_FC_ACCUM(1)
626         TFLITE_SHUFFLED_FC_ACCUM(2)
627         TFLITE_SHUFFLED_FC_ACCUM(3)
628 
629 #undef TFLITE_SHUFFLED_FC_ACCUM
630       }
631       // Horizontally reduce accumulators
632 
633 #define TFLITE_SHUFFLED_FC_STORE(B)                                           \
634   {                                                                           \
635     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,                 \
636         pairwise_reduced_acc_2, pairwise_reduced_acc_3;                       \
637     pairwise_reduced_acc_0 =                                                  \
638         vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
639     pairwise_reduced_acc_1 =                                                  \
640         vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
641     pairwise_reduced_acc_2 =                                                  \
642         vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
643     pairwise_reduced_acc_3 =                                                  \
644         vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
645     const int32x2_t reduced_lo =                                              \
646         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);            \
647     const int32x2_t reduced_hi =                                              \
648         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);            \
649     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);                 \
650     int32x4_t bias_vec = vld1q_s32(bias_data + c);                            \
651     reduced = vaddq_s32(reduced, bias_vec);                                   \
652     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));                    \
653     reduced = vqrdmulhq_n_s32(reduced, output_multiplier);                    \
654     using gemmlowp::RoundingDivideByPOT;                                      \
655     reduced = RoundingDivideByPOT(reduced, right_shift);                      \
656     const int16x4_t res16 = vqmovn_s32(reduced);                              \
657     vst1_s16(output_data + c + B * output_stride, res16);                     \
658   }
659 
660       TFLITE_SHUFFLED_FC_STORE(0);
661       TFLITE_SHUFFLED_FC_STORE(1);
662       TFLITE_SHUFFLED_FC_STORE(2);
663       TFLITE_SHUFFLED_FC_STORE(3);
664 
665 #undef TFLITE_SHUFFLED_FC_STORE
666     }
667   } else {
668     TFLITE_DCHECK(false);
669     return;
670   }
671 #else
672   if (batches == 1) {
673     int16* output_ptr = output_data;
674     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
675     // so that just reinterpreting them as int8 values is equivalent to
676     // subtracting 128 from them, thus implementing for free the subtraction of
677     // the zero_point value 128.
678     const int8* shuffled_weights_ptr =
679         reinterpret_cast<const int8*>(shuffled_weights_data);
680     // Likewise, we preshuffled and pre-xored the input data above.
681     const int8* shuffled_input_data =
682         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
683     for (int c = 0; c < output_depth; c += 4) {
684       // Internal accumulation.
685       // Initialize accumulator with the bias-value.
686       int32 accum[4] = {0};
687       // Accumulation loop.
688       for (int d = 0; d < accum_depth; d += 16) {
689         for (int i = 0; i < 4; i++) {
690           for (int j = 0; j < 16; j++) {
691             int8 input_val = shuffled_input_data[d + j];
692             int8 weights_val = *shuffled_weights_ptr++;
693             accum[i] += weights_val * input_val;
694           }
695         }
696       }
697       for (int i = 0; i < 4; i++) {
698         // Add bias value
699         int acc = accum[i] + bias_data[c + i];
700         // Down-scale the final int32 accumulator to the scale used by our
701         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
702         // multiplier and shift here have been pre-computed offline
703         // (e.g. by toco).
704         acc =
705             MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
706         // Saturate, cast to int16, and store to output array.
707         acc = std::max(acc, -32768);
708         acc = std::min(acc, 32767);
709         output_ptr[c + i] = acc;
710       }
711     }
712   } else if (batches == 4) {
713     int16* output_ptr = output_data;
714     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
715     // so that just reinterpreting them as int8 values is equivalent to
716     // subtracting 128 from them, thus implementing for free the subtraction of
717     // the zero_point value 128.
718     const int8* shuffled_weights_ptr =
719         reinterpret_cast<const int8*>(shuffled_weights_data);
720     // Likewise, we preshuffled and pre-xored the input data above.
721     const int8* shuffled_input_data =
722         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
723     for (int c = 0; c < output_depth; c += 4) {
724       const int8* shuffled_input_ptr = shuffled_input_data;
725       // Accumulation loop.
726       // Internal accumulation.
727       // Initialize accumulator with the bias-value.
728       int32 accum[4][4];
729       for (int i = 0; i < 4; i++) {
730         for (int b = 0; b < 4; b++) {
731           accum[i][b] = 0;
732         }
733       }
734       for (int d = 0; d < accum_depth; d += 16) {
735         for (int i = 0; i < 4; i++) {
736           for (int b = 0; b < 4; b++) {
737             for (int j = 0; j < 16; j++) {
738               int8 input_val = shuffled_input_ptr[16 * b + j];
739               int8 weights_val = shuffled_weights_ptr[16 * i + j];
740               accum[i][b] += weights_val * input_val;
741             }
742           }
743         }
744         shuffled_input_ptr += 64;
745         shuffled_weights_ptr += 64;
746       }
747       for (int i = 0; i < 4; i++) {
748         for (int b = 0; b < 4; b++) {
749           // Add bias value
750           int acc = accum[i][b] + bias_data[c + i];
751           // Down-scale the final int32 accumulator to the scale used by our
752           // (16-bit, typically 3 integer bits) fixed-point format. The
753           // quantized multiplier and shift here have been pre-computed offline
754           // (e.g. by toco).
755           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
756                                               output_shift);
757           // Saturate, cast to int16, and store to output array.
758           acc = std::max(acc, -32768);
759           acc = std::min(acc, 32767);
760           output_ptr[b * output_stride + c + i] = acc;
761         }
762       }
763     }
764   } else {
765     TFLITE_DCHECK(false);
766     return;
767   }
768 #endif
769 }
770 
771 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
772 // to allow using gemmlowp's threadpool.
773 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
ShuffledFullyConnectedWorkerTaskShuffledFullyConnectedWorkerTask774   ShuffledFullyConnectedWorkerTask(const uint8* input_data,
775                                    const int8* shuffled_weights_data,
776                                    int batches, int output_depth,
777                                    int output_stride, int accum_depth,
778                                    const int32* bias_data,
779                                    int32 output_multiplier, int output_shift,
780                                    int16* output_data)
781       : input_data_(input_data),
782         shuffled_weights_data_(shuffled_weights_data),
783         batches_(batches),
784         output_depth_(output_depth),
785         output_stride_(output_stride),
786         accum_depth_(accum_depth),
787         bias_data_(bias_data),
788         output_multiplier_(output_multiplier),
789         output_shift_(output_shift),
790         output_data_(output_data) {}
791 
RunShuffledFullyConnectedWorkerTask792   void Run() override {
793     ShuffledFullyConnectedWorkerImpl(
794         input_data_, shuffled_weights_data_, batches_, output_depth_,
795         output_stride_, accum_depth_, bias_data_, output_multiplier_,
796         output_shift_, output_data_);
797   }
798 
799   const uint8* input_data_;
800   const int8* shuffled_weights_data_;
801   int batches_;
802   int output_depth_;
803   int output_stride_;
804   int accum_depth_;
805   const int32* bias_data_;
806   int32 output_multiplier_;
807   int output_shift_;
808   int16* output_data_;
809 };
810 
ShuffledFullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & weights_shape,const uint8 * shuffled_weights_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,int16 * output_data,uint8 * shuffled_input_workspace_data,CpuBackendContext * cpu_backend_context)811 inline void ShuffledFullyConnected(
812     const FullyConnectedParams& params, const RuntimeShape& input_shape,
813     const uint8* input_data, const RuntimeShape& weights_shape,
814     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
815     const int32* bias_data, const RuntimeShape& output_shape,
816     int16* output_data, uint8* shuffled_input_workspace_data,
817     CpuBackendContext* cpu_backend_context) {
818   ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
819   const int32 output_multiplier = params.output_multiplier;
820   const int output_shift = params.output_shift;
821   const int32 output_activation_min = params.quantized_activation_min;
822   const int32 output_activation_max = params.quantized_activation_max;
823   TFLITE_DCHECK_EQ(output_activation_min, -32768);
824   TFLITE_DCHECK_EQ(output_activation_max, 32767);
825   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
826   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
827   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
828   // TODO(b/62193649): This really should be:
829   //     const int batches = ArraySize(output_dims, 1);
830   // but the current --variable_batch hack consists in overwriting the 3rd
831   // dimension with the runtime batch size, as we don't keep track for each
832   // array of which dimension is the batch dimension in it.
833   const int output_dim_count = output_shape.DimensionsCount();
834   const int weights_dim_count = weights_shape.DimensionsCount();
835   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
836   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
837                                        output_shape, output_dim_count - 1);
838   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
839   TFLITE_DCHECK((accum_depth % 16) == 0);
840   TFLITE_DCHECK((output_depth % 4) == 0);
841   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
842   // so that just reinterpreting them as int8 values is equivalent to
843   // subtracting 128 from them, thus implementing for free the subtraction of
844   // the zero_point value 128.
845   const int8* int8_shuffled_weights_data =
846       reinterpret_cast<const int8*>(shuffled_weights_data);
847 
848   // Shuffling and xoring of input activations into the workspace buffer
849   if (batches == 1) {
850 #ifdef USE_NEON
851     const uint8x16_t signbit = vdupq_n_u8(0x80);
852     for (int i = 0; i < accum_depth; i += 16) {
853       uint8x16_t val = vld1q_u8(input_data + i);
854       val = veorq_u8(val, signbit);
855       vst1q_u8(shuffled_input_workspace_data + i, val);
856     }
857 #else
858     for (int i = 0; i < accum_depth; i++) {
859       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
860     }
861 #endif
862   } else if (batches == 4) {
863     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
864     int c = 0;
865 #ifdef USE_NEON
866     const uint8x16_t signbit = vdupq_n_u8(0x80);
867     for (c = 0; c < accum_depth; c += 16) {
868       const uint8* src_data_ptr = input_data + c;
869       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
870       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
871       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
872       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
873       val0 = veorq_u8(val0, signbit);
874       val1 = veorq_u8(val1, signbit);
875       val2 = veorq_u8(val2, signbit);
876       val3 = veorq_u8(val3, signbit);
877       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
878       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
879       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
880       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
881       shuffled_input_workspace_ptr += 64;
882     }
883 #else
884     for (c = 0; c < accum_depth; c += 16) {
885       for (int b = 0; b < 4; b++) {
886         const uint8* src_data_ptr = input_data + b * accum_depth + c;
887         for (int j = 0; j < 16; j++) {
888           uint8 src_val = *src_data_ptr++;
889           // Flip the sign bit, so that the kernel will only need to
890           // reinterpret these uint8 values as int8, getting for free the
891           // subtraction of the zero_point value 128.
892           uint8 dst_val = src_val ^ 0x80;
893           *shuffled_input_workspace_ptr++ = dst_val;
894         }
895       }
896     }
897 #endif
898   } else {
899     TFLITE_DCHECK(false);
900     return;
901   }
902 
903   static constexpr int kKernelRows = 4;
904   const int thread_count =
905       LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
906                                         output_depth, batches, accum_depth);
907   if (thread_count == 1) {
908     // Single-thread case: do the computation on the current thread, don't
909     // use a threadpool
910     ShuffledFullyConnectedWorkerImpl(
911         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
912         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
913         output_shift, output_data);
914     return;
915   }
916 
917   // Multi-threaded case: use the gemmlowp context's threadpool.
918   TFLITE_DCHECK_GT(thread_count, 1);
919   std::vector<ShuffledFullyConnectedWorkerTask> tasks;
920   // TODO(b/131746020) don't create new heap allocations every time.
921   // At least we make it a single heap allocation by using reserve().
922   tasks.reserve(thread_count);
923   const int kRowsPerWorker =
924       RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
925   int row_start = 0;
926   for (int i = 0; i < thread_count; i++) {
927     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
928     tasks.emplace_back(shuffled_input_workspace_data,
929                        int8_shuffled_weights_data + row_start * accum_depth,
930                        batches, row_end - row_start, output_depth, accum_depth,
931                        bias_data + row_start, output_multiplier, output_shift,
932                        output_data + row_start);
933     row_start = row_end;
934   }
935   TFLITE_DCHECK_EQ(row_start, output_depth);
936   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
937                                   cpu_backend_context);
938 }
939 
940 #ifdef USE_NEON
941 
RoundToNearest(const float32x4_t input)942 inline int32x4_t RoundToNearest(const float32x4_t input) {
943 #if defined(__aarch64__) || defined(__SSSE3__)
944   // Note: vcvtnq_s32_f32 is not available in ARMv7
945   return vcvtnq_s32_f32(input);
946 #else
947   static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
948   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
949   static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
950 
951   const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
952   const float32x4_t round =
953       vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
954   return vcvtq_s32_f32(vaddq_f32(input, round));
955 #endif  // defined(__aarch64__) || defined(__SSSE3__)
956 }
957 
RoundToNearestUnsigned(const float32x4_t input)958 inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
959 #if defined(__aarch64__)
960   // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
961   return vcvtnq_u32_f32(input);
962 #else
963   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
964 
965   return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
966 #endif  // defined(__aarch64__)
967 }
968 
969 #endif  // USE_NEON
970 
MeanImpl(const tflite::MeanParams & op_params,const RuntimeShape & input_shape,const uint8_t * input_data,int32 multiplier,int32 shift,int32 bias,const RuntimeShape & output_shape,uint8_t * output_data,int start_depth,int end_depth)971 inline void MeanImpl(const tflite::MeanParams& op_params,
972                      const RuntimeShape& input_shape, const uint8_t* input_data,
973                      int32 multiplier, int32 shift, int32 bias,
974                      const RuntimeShape& output_shape, uint8_t* output_data,
975                      int start_depth, int end_depth) {
976   ruy::profiler::ScopeLabel label("Mean4D/Uint8/MeanImpl");
977 
978   // Current implementation only supports dimension equals 4 and simultaneous
979   // reduction over width and height.
980   const int output_batch = output_shape.Dims(0);
981   const int output_height = output_shape.Dims(2);
982   const int output_width = output_shape.Dims(2);
983   const int input_height = input_shape.Dims(1);
984   const int input_width = input_shape.Dims(2);
985 
986   TFLITE_CHECK_EQ(op_params.axis_count, 2);
987   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
988                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
989   TFLITE_CHECK_EQ(output_height, 1);
990   TFLITE_CHECK_EQ(output_width, 1);
991 
992   constexpr int32_t kMinValue = std::numeric_limits<uint8_t>::min();
993   constexpr int32_t kMaxValue = std::numeric_limits<uint8_t>::max();
994 
995 #ifdef USE_NEON
996   const int32x4_t bias_dup = vdupq_n_s32(bias);
997   const int32x4_t min_dup = vdupq_n_s32(kMinValue);
998   const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
999 #endif  // USE_NEON
1000 
1001   for (int out_b = 0; out_b < output_batch; ++out_b) {
1002     int out_d = start_depth;
1003 #ifdef USE_NEON
1004 
1005     for (; out_d <= end_depth - 16; out_d += 16) {
1006       int32x4x4_t temp_sum;
1007       temp_sum.val[0] = vdupq_n_s32(0);
1008       temp_sum.val[1] = vdupq_n_s32(0);
1009       temp_sum.val[2] = vdupq_n_s32(0);
1010       temp_sum.val[3] = vdupq_n_s32(0);
1011       for (int in_h = 0; in_h < input_height; ++in_h) {
1012         for (int in_w = 0; in_w < input_width; ++in_w) {
1013           const uint8_t* input_data_ptr =
1014               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1015           uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
1016 
1017           int16x8_t input_data_low_shift =
1018               vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
1019           int16x8_t input_data_high_shift =
1020               vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
1021 
1022           int32x4_t input_low_low =
1023               vmovl_s16(vget_low_s16(input_data_low_shift));
1024           int32x4_t input_high_low =
1025               vmovl_s16(vget_high_s16(input_data_low_shift));
1026           int32x4_t input_low_high =
1027               vmovl_s16(vget_low_s16(input_data_high_shift));
1028           int32x4_t input_high_high =
1029               vmovl_s16(vget_high_s16(input_data_high_shift));
1030 
1031           temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
1032           temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
1033           temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
1034           temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
1035         }
1036       }
1037 
1038       temp_sum =
1039           MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
1040 
1041       temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
1042       temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
1043       temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
1044       temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
1045 
1046       temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
1047       temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
1048       temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
1049       temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
1050 
1051       uint16x4_t narrowed_low_low =
1052           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
1053       uint16x4_t narrowed_high_low =
1054           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
1055       uint16x4_t narrowed_low_high =
1056           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
1057       uint16x4_t narrowed_high_high =
1058           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
1059 
1060       uint16x8_t combined_low =
1061           vcombine_u16(narrowed_low_low, narrowed_high_low);
1062       uint16x8_t combined_high =
1063           vcombine_u16(narrowed_low_high, narrowed_high_high);
1064 
1065       uint8x8_t narrowed_low = vmovn_u16(combined_low);
1066       uint8x8_t narrowed_high = vmovn_u16(combined_high);
1067 
1068       uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
1069 
1070       uint8_t* output_data_ptr =
1071           output_data + Offset(output_shape, out_b, 0, 0, out_d);
1072       vst1q_u8(output_data_ptr, combined_output);
1073     }
1074 #endif  // USE_NEON
1075 
1076     for (; out_d < end_depth; ++out_d) {
1077       int acc = 0;
1078       for (int in_h = 0; in_h < input_height; ++in_h) {
1079         for (int in_w = 0; in_w < input_width; ++in_w) {
1080           acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1081         }
1082       }
1083 
1084       acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
1085       acc += bias;
1086       acc = std::min(std::max(acc, kMinValue), kMaxValue);
1087       output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1088           static_cast<uint8_t>(acc);
1089     }
1090   }
1091 }
1092 
1093 struct MeanWorkerTask : cpu_backend_threadpool::Task {
MeanWorkerTaskMeanWorkerTask1094   MeanWorkerTask(const tflite::MeanParams& op_params,
1095                  const RuntimeShape& input_shape, const uint8_t* input_data,
1096                  int32 multiplier, int32 shift, int32 bias,
1097                  const RuntimeShape& output_shape, uint8_t* output_data,
1098                  int start_height, int end_height)
1099       : op_params(op_params),
1100         input_shape(input_shape),
1101         input_data(input_data),
1102         multiplier(multiplier),
1103         shift(shift),
1104         bias(bias),
1105         output_shape(output_shape),
1106         output_data(output_data),
1107         start_height(start_height),
1108         end_height(end_height) {}
1109 
RunMeanWorkerTask1110   void Run() override {
1111     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1112              output_shape, output_data, start_height, end_height);
1113   }
1114 
1115  private:
1116   const tflite::MeanParams& op_params;
1117   const RuntimeShape& input_shape;
1118   const uint8_t* input_data;
1119   int32 multiplier;
1120   int32 shift;
1121   int32 bias;
1122   const RuntimeShape& output_shape;
1123   uint8_t* output_data;
1124   int start_height;
1125   int end_height;
1126 };
1127 
Mean(const tflite::MeanParams & op_params,const RuntimeShape & unextended_input_shape,const uint8_t * input_data,int32 input_zero_point,float input_scale,const RuntimeShape & unextended_output_shape,uint8_t * output_data,int32 output_zero_point,float output_scale,CpuBackendContext * cpu_backend_context)1128 inline void Mean(const tflite::MeanParams& op_params,
1129                  const RuntimeShape& unextended_input_shape,
1130                  const uint8_t* input_data, int32 input_zero_point,
1131                  float input_scale, const RuntimeShape& unextended_output_shape,
1132                  uint8_t* output_data, int32 output_zero_point,
1133                  float output_scale, CpuBackendContext* cpu_backend_context) {
1134   ruy::profiler::ScopeLabel label("Mean4D/Uint8");
1135   // Current implementation only supports dimension equals 4 and simultaneous
1136   // reduction over width and height.
1137   TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1138   TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1139   const RuntimeShape input_shape =
1140       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1141   const RuntimeShape output_shape =
1142       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1143   const int output_height = output_shape.Dims(1);
1144   const int output_width = output_shape.Dims(2);
1145   const int output_depth = output_shape.Dims(3);
1146 
1147   TFLITE_CHECK_EQ(op_params.axis_count, 2);
1148   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1149                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1150   TFLITE_CHECK_EQ(output_height, 1);
1151   TFLITE_CHECK_EQ(output_width, 1);
1152 
1153   const int input_height = input_shape.Dims(1);
1154   const int input_width = input_shape.Dims(2);
1155   const float num_elements_in_axis = input_width * input_height;
1156 
1157   int32 bias =
1158       output_zero_point -
1159       static_cast<int32>(input_zero_point * input_scale / output_scale);
1160   float real_scale = input_scale / (num_elements_in_axis * output_scale);
1161 
1162   int32 multiplier, shift;
1163   QuantizeMultiplier(real_scale, &multiplier, &shift);
1164 
1165   constexpr int kMinDepthPerThread = 8;
1166   int thread_count = output_depth / kMinDepthPerThread;
1167   thread_count = thread_count > 0 ? thread_count : 1;
1168   const int capped_thread_count =
1169       std::min(thread_count, cpu_backend_context->max_num_threads());
1170 
1171   if (capped_thread_count == 1) {
1172     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1173              output_shape, output_data, 0, output_depth);
1174   } else {
1175     // Instead parallel for batch, we loop for the output_depth since batch
1176     // is typical 1.
1177     std::vector<MeanWorkerTask> tasks;
1178     // TODO(b/131746020) don't create new heap allocations every time.
1179     // At least we make it a single heap allocation by using reserve().
1180     tasks.reserve(capped_thread_count);
1181     int depth_start = 0;
1182     for (int i = 0; i < capped_thread_count; ++i) {
1183       // Try to distribute the tasks as even as possible.
1184       int depth_end = depth_start +
1185                       (output_depth - depth_start) / (capped_thread_count - i);
1186       tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
1187                          bias, output_shape, output_data, depth_start,
1188                          depth_end);
1189       depth_start = depth_end;
1190     }
1191     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
1192                                     cpu_backend_context);
1193   }
1194 }
1195 
1196 template <typename T, typename U>
MeanGeneral(const T * input_data,const int * input_dims,const int input_num_dims,T * output_data,const int * output_dims,const int output_num_dims,const int * axis,const int num_axis_dimensions,bool keep_dims,int * temp_index,int * resolved_axis,U * temp_sum)1197 inline bool MeanGeneral(const T* input_data, const int* input_dims,
1198                         const int input_num_dims, T* output_data,
1199                         const int* output_dims, const int output_num_dims,
1200                         const int* axis, const int num_axis_dimensions,
1201                         bool keep_dims, int* temp_index, int* resolved_axis,
1202                         U* temp_sum) {
1203   return reference_ops::Mean(input_data, input_dims, input_num_dims,
1204                              output_data, output_dims, output_num_dims, axis,
1205                              num_axis_dimensions, keep_dims, temp_index,
1206                              resolved_axis, temp_sum);
1207 }
1208 
1209 template <>
1210 inline bool MeanGeneral<float, float>(
1211     const float* input_data, const int* input_dims, const int input_num_dims,
1212     float* output_data, const int* output_dims, const int output_num_dims,
1213     const int* axis, const int num_axis_dimensions, bool keep_dims,
1214     int* temp_index, int* resolved_axis, float* temp_sum) {
1215   // Handle reduce_mean for the last dimensions.
1216   if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
1217     ruy::profiler::ScopeLabel label("MeanLastDim/Float");
1218     int output_size = 1;
1219     for (int i = 0; i < input_num_dims - 1; ++i) {
1220       output_size *= input_dims[i];
1221     }
1222     const int last_input_dim = input_dims[axis[0]];
1223 
1224     // TODO(b/152563685): Consider use eigen to cover more general cases.
1225     const MatrixMap<const float> in_mat(input_data, last_input_dim,
1226                                         output_size);
1227     VectorMap<float> out(output_data, output_size, 1);
1228     out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
1229     return true;
1230   }
1231 
1232   return reference_ops::Mean(input_data, input_dims, input_num_dims,
1233                              output_data, output_dims, output_num_dims, axis,
1234                              num_axis_dimensions, keep_dims, temp_index,
1235                              resolved_axis, temp_sum);
1236 }
1237 
Conv(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & filter_shape,const float * filter_data,const RuntimeShape & bias_shape,const float * bias_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,float * im2col_data,CpuBackendContext * cpu_backend_context)1238 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1239                  const float* input_data, const RuntimeShape& filter_shape,
1240                  const float* filter_data, const RuntimeShape& bias_shape,
1241                  const float* bias_data, const RuntimeShape& output_shape,
1242                  float* output_data, const RuntimeShape& im2col_shape,
1243                  float* im2col_data, CpuBackendContext* cpu_backend_context) {
1244   const int stride_width = params.stride_width;
1245   const int stride_height = params.stride_height;
1246   const int dilation_width_factor = params.dilation_width_factor;
1247   const int dilation_height_factor = params.dilation_height_factor;
1248   const float output_activation_min = params.float_activation_min;
1249   const float output_activation_max = params.float_activation_max;
1250   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1251   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1252   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1253 
1254   ruy::profiler::ScopeLabel label("Conv");
1255 
1256   // NB: the float 0.0f value is represented by all zero bytes.
1257   const uint8 float_zero_byte = 0x00;
1258   const float* gemm_input_data = nullptr;
1259   const RuntimeShape* gemm_input_shape = nullptr;
1260   const int filter_width = filter_shape.Dims(2);
1261   const int filter_height = filter_shape.Dims(1);
1262   const bool need_dilated_im2col =
1263       dilation_width_factor != 1 || dilation_height_factor != 1;
1264   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1265                            filter_width != 1 || filter_height != 1;
1266   if (need_dilated_im2col) {
1267     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
1268                   filter_shape, output_shape, im2col_data);
1269     gemm_input_data = im2col_data;
1270     gemm_input_shape = &im2col_shape;
1271   } else if (need_im2col) {
1272     TFLITE_DCHECK(im2col_data);
1273     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
1274            input_data, im2col_shape, im2col_data);
1275     gemm_input_data = im2col_data;
1276     gemm_input_shape = &im2col_shape;
1277   } else {
1278     TFLITE_DCHECK(!im2col_data);
1279     gemm_input_data = input_data;
1280     gemm_input_shape = &input_shape;
1281   }
1282 
1283   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
1284   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
1285   int n = output_shape.Dims(3);
1286   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
1287 
1288 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1289   // The following code computes matrix multiplication c = a * transponse(b)
1290   // with CBLAS, where:
1291   // * `a` is a matrix with dimensions (m, k).
1292   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
1293   // * `c` is a matrix with dimensions (m, n).
1294   // The naming of variables are aligned with CBLAS specification here.
1295   const float* a = gemm_input_data;
1296   const float* b = filter_data;
1297   float* c = output_data;
1298   // The stride of matrix a, b and c respectively.
1299   int stride_a = k;
1300   int stride_b = k;
1301   int stride_c = n;
1302 
1303   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
1304               stride_a, b, stride_b, 0.0f, c, stride_c);
1305   optimized_ops::AddBiasAndEvalActivationFunction(
1306       output_activation_min, output_activation_max, bias_shape, bias_data,
1307       output_shape, output_data);
1308 #else
1309   // When an optimized CBLAS implementation is not available, fall back
1310   // to using cpu_backend_gemm.
1311   cpu_backend_gemm::MatrixParams<float> lhs_params;
1312   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1313   lhs_params.rows = n;
1314   lhs_params.cols = k;
1315   cpu_backend_gemm::MatrixParams<float> rhs_params;
1316   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1317   rhs_params.rows = k;
1318   rhs_params.cols = m;
1319   cpu_backend_gemm::MatrixParams<float> dst_params;
1320   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1321   dst_params.rows = n;
1322   dst_params.cols = m;
1323   cpu_backend_gemm::GemmParams<float, float> gemm_params;
1324   gemm_params.bias = bias_data;
1325   gemm_params.clamp_min = output_activation_min;
1326   gemm_params.clamp_max = output_activation_max;
1327   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1328                          dst_params, output_data, gemm_params,
1329                          cpu_backend_context);
1330 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1331 }
1332 
HybridConv(const ConvParams & params,float * scaling_factors_ptr,const RuntimeShape & input_shape,const int8_t * input_data,const RuntimeShape & filter_shape,const int8_t * filter_data,const RuntimeShape & bias_shape,const float * bias_data,const RuntimeShape & accum_scratch_shape,int32_t * accum_scratch,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,int8_t * im2col_data,CpuBackendContext * context)1333 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
1334                        const RuntimeShape& input_shape,
1335                        const int8_t* input_data,
1336                        const RuntimeShape& filter_shape,
1337                        const int8_t* filter_data,
1338                        const RuntimeShape& bias_shape, const float* bias_data,
1339                        const RuntimeShape& accum_scratch_shape,
1340                        int32_t* accum_scratch, const RuntimeShape& output_shape,
1341                        float* output_data, const RuntimeShape& im2col_shape,
1342                        int8_t* im2col_data, CpuBackendContext* context) {
1343   const int stride_width = params.stride_width;
1344   const int stride_height = params.stride_height;
1345   const int dilation_width_factor = params.dilation_width_factor;
1346   const int dilation_height_factor = params.dilation_height_factor;
1347   const float output_activation_min = params.float_activation_min;
1348   const float output_activation_max = params.float_activation_max;
1349   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1350   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1351   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1352 
1353   const int batch_size = input_shape.Dims(0);
1354   const int filter_width = filter_shape.Dims(2);
1355   const int filter_height = filter_shape.Dims(1);
1356 
1357   const int input_zero_point = 0;
1358   const int8_t* gemm_input_data = nullptr;
1359   int num_input;
1360   const bool need_dilated_im2col =
1361       dilation_width_factor != 1 || dilation_height_factor != 1;
1362   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1363                            filter_width != 1 || filter_height != 1;
1364 
1365   if (need_dilated_im2col) {
1366     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1367                   filter_shape, output_shape, im2col_data);
1368     gemm_input_data = im2col_data;
1369     num_input = im2col_shape.FlatSize();
1370   } else if (need_im2col) {
1371     TFLITE_DCHECK(im2col_data);
1372     // symmetric quantization assumes zero point of 0.
1373 
1374     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1375            input_data, im2col_shape, im2col_data);
1376     gemm_input_data = im2col_data;
1377     num_input = im2col_shape.FlatSize();
1378   } else {
1379     TFLITE_DCHECK(!im2col_data);
1380     gemm_input_data = input_data;
1381     num_input = input_shape.FlatSize();
1382   }
1383 
1384   // Flatten 4D matrices into 2D matrices for matrix multiplication.
1385 
1386   // Flatten so that each filter has its own row.
1387   const int filter_rows = filter_shape.Dims(0);
1388   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1389 
1390   // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1391   // dot product of one row of the first matrix with one row of the second
1392   // matrix. Therefore, the number of cols in each matrix are equivalent.
1393   //
1394   // After Im2Col, each input patch becomes a row.
1395   const int gemm_input_cols = filter_cols;
1396   const int gemm_input_rows = num_input / gemm_input_cols;
1397 
1398   const int output_cols = output_shape.Dims(3);
1399   const int output_rows = FlatSizeSkipDim(output_shape, 3);
1400   TFLITE_DCHECK_EQ(output_cols, filter_rows);
1401   TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1402   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1403 
1404   // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1405   // input matrix has its own scale factor. This code duplicates the scale
1406   // factors for each row in the same batch.
1407   const int rows_per_batch = gemm_input_rows / batch_size;
1408   for (int i = gemm_input_rows - 1; i >= 0; --i) {
1409     scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1410   }
1411 
1412   std::fill_n(output_data, output_rows * output_cols, 0.0f);
1413 
1414   // The scratch buffer must have the same size as the output.
1415   TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1416   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1417       filter_data, filter_rows, filter_cols, gemm_input_data,
1418       scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1419       output_data, context);
1420   AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1421                                    bias_shape, bias_data, output_shape,
1422                                    output_data);
1423 }
1424 
HybridConvPerChannel(const ConvParams & params,float * scaling_factors_ptr,const RuntimeShape & input_shape,const int8_t * input_data,const RuntimeShape & filter_shape,const int8_t * filter_data,const RuntimeShape & bias_shape,const float * bias_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,int8_t * im2col_data,const float * per_channel_scale,int32_t * input_offset,const RuntimeShape & scratch_shape,int32_t * scratch,int32_t * row_sums,bool * compute_row_sums,CpuBackendContext * cpu_backend_context)1425 inline void HybridConvPerChannel(
1426     const ConvParams& params, float* scaling_factors_ptr,
1427     const RuntimeShape& input_shape, const int8_t* input_data,
1428     const RuntimeShape& filter_shape, const int8_t* filter_data,
1429     const RuntimeShape& bias_shape, const float* bias_data,
1430     const RuntimeShape& output_shape, float* output_data,
1431     const RuntimeShape& im2col_shape, int8_t* im2col_data,
1432     const float* per_channel_scale, int32_t* input_offset,
1433     const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
1434     bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1435   ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1436   const int stride_width = params.stride_width;
1437   const int stride_height = params.stride_height;
1438   const int dilation_width_factor = params.dilation_width_factor;
1439   const int dilation_height_factor = params.dilation_height_factor;
1440   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1441   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1442   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1443 
1444   const int8* gemm_input_data = nullptr;
1445   const RuntimeShape* gemm_input_shape = nullptr;
1446   const int filter_width = filter_shape.Dims(2);
1447   const int filter_height = filter_shape.Dims(1);
1448   const bool need_dilated_im2col =
1449       dilation_width_factor != 1 || dilation_height_factor != 1;
1450   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1451                            filter_width != 1 || filter_height != 1;
1452 
1453   const int batch_size = input_shape.Dims(0);
1454 
1455   if (need_dilated_im2col) {
1456     TFLITE_DCHECK(im2col_data);
1457     optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1458                                  output_shape, im2col_data, input_offset,
1459                                  batch_size);
1460     gemm_input_data = im2col_data;
1461     gemm_input_shape = &im2col_shape;
1462   } else if (need_im2col) {
1463     Im2col(params, filter_height, filter_width, input_offset, batch_size,
1464            input_shape, input_data, im2col_shape, im2col_data);
1465     gemm_input_data = im2col_data;
1466     gemm_input_shape = &im2col_shape;
1467   } else {
1468     TFLITE_DCHECK(!im2col_data);
1469     gemm_input_data = input_data;
1470     gemm_input_shape = &input_shape;
1471   }
1472 
1473   const int filter_rows = filter_shape.Dims(0);
1474   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1475 
1476   const int gemm_input_rows = gemm_input_shape->Dims(3);
1477   const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1478   const int output_rows = output_shape.Dims(3);
1479   const int output_cols =
1480       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1481 
1482   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1483   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1484   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1485   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1486   TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
1487   if (!compute_row_sums || *compute_row_sums) {
1488     tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1489                                      filter_cols);
1490     if (compute_row_sums) {
1491       *compute_row_sums = false;
1492     }
1493   }
1494 
1495   cpu_backend_gemm::MatrixParams<int8> lhs_params;
1496   lhs_params.rows = filter_rows;
1497   lhs_params.cols = filter_cols;
1498   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1499 
1500   cpu_backend_gemm::MatrixParams<int8> rhs_params;
1501   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1502   rhs_params.rows = gemm_input_rows;
1503   rhs_params.cols = gemm_input_cols;
1504 
1505   cpu_backend_gemm::MatrixParams<int32> dst_params;
1506   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1507   dst_params.rows = output_rows;
1508   dst_params.cols = output_cols;
1509 
1510   // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1511   cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1512   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1513                          dst_params, scratch, gemm_params, cpu_backend_context);
1514 
1515   MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1516   MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1517   VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1518   VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1519   VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1520                                                1);
1521   const int cols_per_batch = output_cols / batch_size;
1522   for (int c = 0; c < output_cols; c++) {
1523     const int b = c / cols_per_batch;
1524     const float input_scale = scaling_factors_ptr[b];
1525     const int32_t zero_point = input_offset[b];
1526     out_mat.col(c) =
1527         (((in_mat.col(c) - (row_sums_vec * zero_point))
1528               .cast<float>()
1529               .cwiseProduct((per_channel_scale_vec * input_scale))) +
1530          bias_data_vec)
1531             .cwiseMin(params.float_activation_max)
1532             .cwiseMax(params.float_activation_min);
1533   }
1534 }
1535 
Conv(const ConvParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8 * output_data,const RuntimeShape & im2col_shape,uint8 * im2col_data,CpuBackendContext * cpu_backend_context)1536 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1537                  const uint8* input_data, const RuntimeShape& filter_shape,
1538                  const uint8* filter_data, const RuntimeShape& bias_shape,
1539                  const int32* bias_data, const RuntimeShape& output_shape,
1540                  uint8* output_data, const RuntimeShape& im2col_shape,
1541                  uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1542   ruy::profiler::ScopeLabel label("Conv/8bit");
1543 
1544   const int stride_width = params.stride_width;
1545   const int stride_height = params.stride_height;
1546   const int dilation_width_factor = params.dilation_width_factor;
1547   const int dilation_height_factor = params.dilation_height_factor;
1548   const int32 input_offset = params.input_offset;
1549   const int32 filter_offset = params.weights_offset;
1550   const int32 output_offset = params.output_offset;
1551   const int32 output_multiplier = params.output_multiplier;
1552   const int output_shift = params.output_shift;
1553   const int32 output_activation_min = params.quantized_activation_min;
1554   const int32 output_activation_max = params.quantized_activation_max;
1555   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1556   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1557   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1558 
1559   const uint8* gemm_input_data = nullptr;
1560   const RuntimeShape* gemm_input_shape = nullptr;
1561   const int filter_width = filter_shape.Dims(2);
1562   const int filter_height = filter_shape.Dims(1);
1563   const bool need_dilated_im2col =
1564       dilation_width_factor != 1 || dilation_height_factor != 1;
1565   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1566                            filter_width != 1 || filter_height != 1;
1567   if (need_dilated_im2col) {
1568     TFLITE_DCHECK(im2col_data);
1569     const int input_zero_point = -input_offset;
1570     TFLITE_DCHECK_GE(input_zero_point, 0);
1571     TFLITE_DCHECK_LE(input_zero_point, 255);
1572     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1573                   filter_shape, output_shape, im2col_data);
1574     gemm_input_data = im2col_data;
1575     gemm_input_shape = &im2col_shape;
1576   } else if (need_im2col) {
1577     TFLITE_DCHECK(im2col_data);
1578     const int input_zero_point = -input_offset;
1579     TFLITE_DCHECK_GE(input_zero_point, 0);
1580     TFLITE_DCHECK_LE(input_zero_point, 255);
1581     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1582            input_data, im2col_shape, im2col_data);
1583     gemm_input_data = im2col_data;
1584     gemm_input_shape = &im2col_shape;
1585   } else {
1586     TFLITE_DCHECK(!im2col_data);
1587     gemm_input_data = input_data;
1588     gemm_input_shape = &input_shape;
1589   }
1590 
1591   const int gemm_input_rows = gemm_input_shape->Dims(3);
1592   // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
1593   // The root cause has not yet been identified though. Same applies below for
1594   // the other calls commented out. This is a partial rollback of cl/196819423.
1595   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1596   const int gemm_input_cols = gemm_input_shape->Dims(0) *
1597                               gemm_input_shape->Dims(1) *
1598                               gemm_input_shape->Dims(2);
1599   const int filter_rows = filter_shape.Dims(0);
1600   // See b/79927784.
1601   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1602   const int filter_cols =
1603       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1604   const int output_rows = output_shape.Dims(3);
1605   // See b/79927784.
1606   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1607   const int output_cols =
1608       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1609   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1610   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1611   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1612   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1613 
1614   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1615   lhs_params.rows = filter_rows;
1616   lhs_params.cols = filter_cols;
1617   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1618   lhs_params.zero_point = -filter_offset;
1619   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1620   rhs_params.rows = gemm_input_rows;
1621   rhs_params.cols = gemm_input_cols;
1622   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1623   rhs_params.zero_point = -input_offset;
1624   cpu_backend_gemm::MatrixParams<uint8> dst_params;
1625   dst_params.rows = output_rows;
1626   dst_params.cols = output_cols;
1627   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1628   dst_params.zero_point = output_offset;
1629   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1630   gemm_params.bias = bias_data;
1631   gemm_params.clamp_min = output_activation_min;
1632   gemm_params.clamp_max = output_activation_max;
1633   gemm_params.multiplier_fixedpoint = output_multiplier;
1634   gemm_params.multiplier_exponent = output_shift;
1635   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1636                          dst_params, output_data, gemm_params,
1637                          cpu_backend_context);
1638 }
1639 
1640 template <typename T>
DepthToSpace(const tflite::DepthToSpaceParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_shape,T * output_data)1641 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1642                          const RuntimeShape& unextended_input_shape,
1643                          const T* input_data,
1644                          const RuntimeShape& unextended_output_shape,
1645                          T* output_data) {
1646   ruy::profiler::ScopeLabel label("DepthToSpace");
1647 
1648   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1649   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1650   const RuntimeShape input_shape =
1651       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1652   const RuntimeShape output_shape =
1653       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1654 
1655   const int input_depth = input_shape.Dims(3);
1656   const int input_width = input_shape.Dims(2);
1657   const int input_height = input_shape.Dims(1);
1658 
1659   const int output_depth = output_shape.Dims(3);
1660   const int batch_size = output_shape.Dims(0);
1661 
1662   // Number of continuous values that we can copy in one interation.
1663   const int stride = op_params.block_size * output_depth;
1664 
1665   for (int batch = 0; batch < batch_size; ++batch) {
1666     for (int in_h = 0; in_h < input_height; ++in_h) {
1667       const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1668       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1669         const T* src = input_ptr;
1670         for (int in_w = 0; in_w < input_width; ++in_w) {
1671           memcpy(output_data, src, stride * sizeof(T));
1672           output_data += stride;
1673           src += input_depth;
1674         }
1675         input_ptr += stride;
1676       }
1677     }
1678   }
1679 }
1680 
1681 template <typename T>
SpaceToDepth(const tflite::SpaceToDepthParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_shape,T * output_data)1682 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1683                          const RuntimeShape& unextended_input_shape,
1684                          const T* input_data,
1685                          const RuntimeShape& unextended_output_shape,
1686                          T* output_data) {
1687   ruy::profiler::ScopeLabel label("SpaceToDepth");
1688 
1689   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1690   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1691   const RuntimeShape input_shape =
1692       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1693   const RuntimeShape output_shape =
1694       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1695 
1696   const int output_depth = output_shape.Dims(3);
1697   const int output_width = output_shape.Dims(2);
1698   const int output_height = output_shape.Dims(1);
1699 
1700   const int input_depth = input_shape.Dims(3);
1701   const int batch_size = input_shape.Dims(0);
1702 
1703   // Number of continuous values that we can copy in one interation.
1704   const int stride = op_params.block_size * input_depth;
1705 
1706   for (int batch = 0; batch < batch_size; ++batch) {
1707     for (int out_h = 0; out_h < output_height; ++out_h) {
1708       T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1709       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1710         T* dst = output_ptr;
1711         for (int out_w = 0; out_w < output_width; ++out_w) {
1712           memcpy(dst, input_data, stride * sizeof(T));
1713           input_data += stride;
1714           dst += output_depth;
1715         }
1716         output_ptr += stride;
1717       }
1718     }
1719   }
1720 }
1721 
Relu(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)1722 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1723                  const RuntimeShape& output_shape, float* output_data) {
1724   ruy::profiler::ScopeLabel label("Relu (not fused)");
1725 
1726   const auto input = MapAsVector(input_data, input_shape);
1727   auto output = MapAsVector(output_data, output_shape);
1728   output = input.cwiseMax(0.0f);
1729 }
1730 
1731 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1732                             const RuntimeShape& input_shape,
1733                             const float* input_data,
1734                             const RuntimeShape& output_shape,
1735                             float* output_data, float epsilon = 1e-6) {
1736   ruy::profiler::ScopeLabel label("L2Normalization");
1737   const int trailing_dim = input_shape.DimensionsCount() - 1;
1738   const int outer_size =
1739       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1740   const int depth =
1741       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1742   for (int i = 0; i < outer_size; ++i) {
1743     float squared_l2_norm = 0;
1744     for (int c = 0; c < depth; ++c) {
1745       const float val = input_data[c];
1746       squared_l2_norm += val * val;
1747     }
1748     float l2_norm = std::sqrt(squared_l2_norm);
1749     l2_norm = std::max(l2_norm, epsilon);
1750     for (int c = 0; c < depth; ++c) {
1751       *output_data = *input_data / l2_norm;
1752       ++output_data;
1753       ++input_data;
1754     }
1755   }
1756 }
1757 
L2Normalization(const tflite::L2NormalizationParams & op_params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)1758 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1759                             const RuntimeShape& input_shape,
1760                             const uint8* input_data,
1761                             const RuntimeShape& output_shape,
1762                             uint8* output_data) {
1763   ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1764   const int trailing_dim = input_shape.DimensionsCount() - 1;
1765   const int depth =
1766       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1767   const int outer_size =
1768       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1769   const int32 input_zero_point = op_params.input_zero_point;
1770   for (int i = 0; i < outer_size; ++i) {
1771     int32 square_l2_norm = 0;
1772     for (int c = 0; c < depth; c++) {
1773       // Note that input_data advances by depth in the second pass below.
1774       int32 diff = input_data[c] - input_zero_point;
1775       square_l2_norm += diff * diff;
1776     }
1777     // TODO(b/29395854): add clamping to TOCO and TF Lite kernel
1778     // for all zero tensors in the input_data
1779     int32 inv_l2norm_multiplier;
1780     int inv_l2norm_shift;
1781     GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1782                                      &inv_l2norm_multiplier, &inv_l2norm_shift);
1783 
1784     for (int c = 0; c < depth; c++) {
1785       int32 diff = *input_data - input_zero_point;
1786       int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1787           128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1788       int32 unclamped_output_val = 128 + rescaled_diff;
1789       int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1790       *output_data = static_cast<uint8>(output_val);
1791       ++input_data;
1792       ++output_data;
1793     }
1794   }
1795 }
1796 
AddElementwise(int size,const ArithmeticParams & params,const float * input1_data,const float * input2_data,float * output_data)1797 inline void AddElementwise(int size, const ArithmeticParams& params,
1798                            const float* input1_data, const float* input2_data,
1799                            float* output_data) {
1800   int i = 0;
1801 
1802 #ifdef USE_NEON
1803   const auto activation_min = vdupq_n_f32(params.float_activation_min);
1804   const auto activation_max = vdupq_n_f32(params.float_activation_max);
1805   for (; i <= size - 16; i += 16) {
1806     auto a10 = vld1q_f32(input1_data + i);
1807     auto a11 = vld1q_f32(input1_data + i + 4);
1808     auto a12 = vld1q_f32(input1_data + i + 8);
1809     auto a13 = vld1q_f32(input1_data + i + 12);
1810     auto a20 = vld1q_f32(input2_data + i);
1811     auto a21 = vld1q_f32(input2_data + i + 4);
1812     auto a22 = vld1q_f32(input2_data + i + 8);
1813     auto a23 = vld1q_f32(input2_data + i + 12);
1814     auto x0 = vaddq_f32(a10, a20);
1815     auto x1 = vaddq_f32(a11, a21);
1816     auto x2 = vaddq_f32(a12, a22);
1817     auto x3 = vaddq_f32(a13, a23);
1818     x0 = vmaxq_f32(activation_min, x0);
1819     x1 = vmaxq_f32(activation_min, x1);
1820     x2 = vmaxq_f32(activation_min, x2);
1821     x3 = vmaxq_f32(activation_min, x3);
1822     x0 = vminq_f32(activation_max, x0);
1823     x1 = vminq_f32(activation_max, x1);
1824     x2 = vminq_f32(activation_max, x2);
1825     x3 = vminq_f32(activation_max, x3);
1826     vst1q_f32(output_data + i, x0);
1827     vst1q_f32(output_data + i + 4, x1);
1828     vst1q_f32(output_data + i + 8, x2);
1829     vst1q_f32(output_data + i + 12, x3);
1830   }
1831   for (; i <= size - 4; i += 4) {
1832     auto a1 = vld1q_f32(input1_data + i);
1833     auto a2 = vld1q_f32(input2_data + i);
1834     auto x = vaddq_f32(a1, a2);
1835     x = vmaxq_f32(activation_min, x);
1836     x = vminq_f32(activation_max, x);
1837     vst1q_f32(output_data + i, x);
1838   }
1839 #endif  // NEON
1840 
1841   for (; i < size; i++) {
1842     auto x = input1_data[i] + input2_data[i];
1843     output_data[i] = ActivationFunctionWithMinMax(
1844         x, params.float_activation_min, params.float_activation_max);
1845   }
1846 }
1847 
Add(const ArithmeticParams & params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,float * output_data)1848 inline void Add(const ArithmeticParams& params,
1849                 const RuntimeShape& input1_shape, const float* input1_data,
1850                 const RuntimeShape& input2_shape, const float* input2_data,
1851                 const RuntimeShape& output_shape, float* output_data) {
1852   ruy::profiler::ScopeLabel label("Add");
1853   const int flat_size =
1854       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1855   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1856 }
1857 
1858 // Element-wise add that can often be used for inner loop of broadcast add as
1859 // well as the non-broadcast add.
AddElementwise(int size,const ArithmeticParams & params,const uint8 * input1_data,const uint8 * input2_data,uint8 * output_data)1860 inline void AddElementwise(int size, const ArithmeticParams& params,
1861                            const uint8* input1_data, const uint8* input2_data,
1862                            uint8* output_data) {
1863   ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1864   int i = 0;
1865   TFLITE_DCHECK_GT(params.input1_offset, -256);
1866   TFLITE_DCHECK_GT(params.input2_offset, -256);
1867   TFLITE_DCHECK_LT(params.input1_offset, 256);
1868   TFLITE_DCHECK_LT(params.input2_offset, 256);
1869 #ifdef USE_NEON
1870   const uint8x8_t output_activation_min_vector =
1871       vdup_n_u8(params.quantized_activation_min);
1872   const uint8x8_t output_activation_max_vector =
1873       vdup_n_u8(params.quantized_activation_max);
1874   for (; i <= size - 8; i += 8) {
1875     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1876     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1877     const int16x8_t input1_val_s16 =
1878         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1879     const int16x8_t input2_val_s16 =
1880         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1881     const int16x8_t input1_val =
1882         vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1883     const int16x8_t input2_val =
1884         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1885     const int16x4_t input1_val_high = vget_high_s16(input1_val);
1886     const int16x4_t input1_val_low = vget_low_s16(input1_val);
1887     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1888     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1889     int32x4_t x11 = vmovl_s16(input1_val_low);
1890     int32x4_t x12 = vmovl_s16(input1_val_high);
1891     int32x4_t x21 = vmovl_s16(input2_val_low);
1892     int32x4_t x22 = vmovl_s16(input2_val_high);
1893     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1894     x11 = vshlq_s32(x11, left_shift_dup);
1895     x12 = vshlq_s32(x12, left_shift_dup);
1896     x21 = vshlq_s32(x21, left_shift_dup);
1897     x22 = vshlq_s32(x22, left_shift_dup);
1898     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1899     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1900     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1901     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1902     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1903     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1904     x11 = vshlq_s32(x11, input1_shift_dup);
1905     x12 = vshlq_s32(x12, input1_shift_dup);
1906     x21 = vshlq_s32(x21, input2_shift_dup);
1907     x22 = vshlq_s32(x22, input2_shift_dup);
1908     int32x4_t s1 = vaddq_s32(x11, x21);
1909     int32x4_t s2 = vaddq_s32(x12, x22);
1910     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1911     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1912     using gemmlowp::RoundingDivideByPOT;
1913     s1 = RoundingDivideByPOT(s1, -params.output_shift);
1914     s2 = RoundingDivideByPOT(s2, -params.output_shift);
1915     const int16x4_t s1_narrowed = vmovn_s32(s1);
1916     const int16x4_t s2_narrowed = vmovn_s32(s2);
1917     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1918                                   vdupq_n_s16(params.output_offset));
1919     const uint8x8_t clamped =
1920         vmax_u8(output_activation_min_vector,
1921                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1922     vst1_u8(output_data + i, clamped);
1923   }
1924 #endif  // NEON
1925 
1926   for (; i < size; ++i) {
1927     const int32 input1_val = params.input1_offset + input1_data[i];
1928     const int32 input2_val = params.input2_offset + input2_data[i];
1929     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1930     const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1931     const int32 scaled_input1_val =
1932         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1933             shifted_input1_val, params.input1_multiplier, params.input1_shift);
1934     const int32 scaled_input2_val =
1935         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1936             shifted_input2_val, params.input2_multiplier, params.input2_shift);
1937     const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1938     const int32 raw_output =
1939         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1940             raw_sum, params.output_multiplier, params.output_shift) +
1941         params.output_offset;
1942     const int32 clamped_output =
1943         std::min(params.quantized_activation_max,
1944                  std::max(params.quantized_activation_min, raw_output));
1945     output_data[i] = static_cast<uint8>(clamped_output);
1946   }
1947 }
1948 
1949 // Scalar-broadcast add that can be used for inner loop of more general
1950 // broadcast add, so that, for example, scalar-broadcast with batch will still
1951 // be fast.
AddScalarBroadcast(int size,const ArithmeticParams & params,uint8 input1_data,const uint8 * input2_data,uint8 * output_data)1952 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1953                                uint8 input1_data, const uint8* input2_data,
1954                                uint8* output_data) {
1955   using gemmlowp::RoundingDivideByPOT;
1956 
1957   ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1958   TFLITE_DCHECK_GT(params.input1_offset, -256);
1959   TFLITE_DCHECK_GT(params.input2_offset, -256);
1960   TFLITE_DCHECK_LT(params.input1_offset, 256);
1961   TFLITE_DCHECK_LT(params.input2_offset, 256);
1962 
1963   int i = 0;
1964 
1965 #ifdef USE_NEON
1966   const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1967   const uint8x8_t output_activation_min_vector =
1968       vdup_n_u8(params.quantized_activation_min);
1969   const uint8x8_t output_activation_max_vector =
1970       vdup_n_u8(params.quantized_activation_max);
1971 
1972   // Process broadcast scalar.
1973   const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1974   const int16x8_t input1_val_s16 =
1975       vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1976   const int16x8_t input1_val =
1977       vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1978   const int16x4_t input1_val_high = vget_high_s16(input1_val);
1979   const int16x4_t input1_val_low = vget_low_s16(input1_val);
1980   int32x4_t x11 = vmovl_s16(input1_val_low);
1981   int32x4_t x12 = vmovl_s16(input1_val_high);
1982   x11 = vshlq_s32(x11, left_shift_dup);
1983   x12 = vshlq_s32(x12, left_shift_dup);
1984   x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1985   x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1986   const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1987   x11 = vshlq_s32(x11, input1_shift_dup);
1988   x12 = vshlq_s32(x12, input1_shift_dup);
1989 
1990   for (; i <= size - 8; i += 8) {
1991     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1992     const int16x8_t input2_val_s16 =
1993         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1994     const int16x8_t input2_val =
1995         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1996     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1997     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1998     int32x4_t x21 = vmovl_s16(input2_val_low);
1999     int32x4_t x22 = vmovl_s16(input2_val_high);
2000     x21 = vshlq_s32(x21, left_shift_dup);
2001     x22 = vshlq_s32(x22, left_shift_dup);
2002     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2003     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2004     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2005     x21 = vshlq_s32(x21, input2_shift_dup);
2006     x22 = vshlq_s32(x22, input2_shift_dup);
2007     int32x4_t s1 = vaddq_s32(x11, x21);
2008     int32x4_t s2 = vaddq_s32(x12, x22);
2009     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2010     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2011     s1 = RoundingDivideByPOT(s1, -params.output_shift);
2012     s2 = RoundingDivideByPOT(s2, -params.output_shift);
2013     const int16x4_t s1_narrowed = vmovn_s32(s1);
2014     const int16x4_t s2_narrowed = vmovn_s32(s2);
2015     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2016                                   vdupq_n_s16(params.output_offset));
2017     const uint8x8_t clamped =
2018         vmax_u8(output_activation_min_vector,
2019                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2020     vst1_u8(output_data + i, clamped);
2021   }
2022 #endif  // NEON
2023 
2024   if (i < size) {
2025     // Process broadcast scalar.
2026     const int32 input1_val = params.input1_offset + input1_data;
2027     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2028     const int32 scaled_input1_val =
2029         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2030             shifted_input1_val, params.input1_multiplier, params.input1_shift);
2031 
2032     for (; i < size; ++i) {
2033       const int32 input2_val = params.input2_offset + input2_data[i];
2034       const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2035       const int32 scaled_input2_val =
2036           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2037               shifted_input2_val, params.input2_multiplier,
2038               params.input2_shift);
2039       const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2040       const int32 raw_output =
2041           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2042               raw_sum, params.output_multiplier, params.output_shift) +
2043           params.output_offset;
2044       const int32 clamped_output =
2045           std::min(params.quantized_activation_max,
2046                    std::max(params.quantized_activation_min, raw_output));
2047       output_data[i] = static_cast<uint8>(clamped_output);
2048     }
2049   }
2050 }
2051 
2052 // Scalar-broadcast add that can be used for inner loop of more general
2053 // broadcast add, so that, for example, scalar-broadcast with batch will still
2054 // be fast.
AddScalarBroadcast(int size,const ArithmeticParams & params,float broadcast_value,const float * input2_data,float * output_data)2055 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2056                                float broadcast_value, const float* input2_data,
2057                                float* output_data) {
2058   int i = 0;
2059 #ifdef USE_NEON
2060   const float32x4_t output_activation_min_vector =
2061       vdupq_n_f32(params.float_activation_min);
2062   const float32x4_t output_activation_max_vector =
2063       vdupq_n_f32(params.float_activation_max);
2064   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2065   for (; i <= size - 4; i += 4) {
2066     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2067 
2068     const float32x4_t output =
2069         vaddq_f32(input2_val_original, broadcast_value_dup);
2070 
2071     const float32x4_t clamped =
2072         vmaxq_f32(output_activation_min_vector,
2073                   vminq_f32(output_activation_max_vector, output));
2074     vst1q_f32(output_data + i, clamped);
2075   }
2076 #endif  // NEON
2077 
2078   for (; i < size; ++i) {
2079     auto x = broadcast_value + input2_data[i];
2080     output_data[i] = ActivationFunctionWithMinMax(
2081         x, params.float_activation_min, params.float_activation_max);
2082   }
2083 }
2084 
Add(const ArithmeticParams & params,const RuntimeShape & input1_shape,const uint8 * input1_data,const RuntimeShape & input2_shape,const uint8 * input2_data,const RuntimeShape & output_shape,uint8 * output_data)2085 inline void Add(const ArithmeticParams& params,
2086                 const RuntimeShape& input1_shape, const uint8* input1_data,
2087                 const RuntimeShape& input2_shape, const uint8* input2_data,
2088                 const RuntimeShape& output_shape, uint8* output_data) {
2089   TFLITE_DCHECK_LE(params.quantized_activation_min,
2090                    params.quantized_activation_max);
2091   ruy::profiler::ScopeLabel label("Add/8bit");
2092   const int flat_size =
2093       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2094 
2095   TFLITE_DCHECK_GT(params.input1_offset, -256);
2096   TFLITE_DCHECK_GT(params.input2_offset, -256);
2097   TFLITE_DCHECK_LT(params.input1_offset, 256);
2098   TFLITE_DCHECK_LT(params.input2_offset, 256);
2099   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2100 }
2101 
Add(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16 * input1_data,const RuntimeShape & input2_shape,const int16 * input2_data,const RuntimeShape & output_shape,int16 * output_data)2102 inline void Add(const ArithmeticParams& params,
2103                 const RuntimeShape& input1_shape, const int16* input1_data,
2104                 const RuntimeShape& input2_shape, const int16* input2_data,
2105                 const RuntimeShape& output_shape, int16* output_data) {
2106   ruy::profiler::ScopeLabel label("Add/Int16");
2107   TFLITE_DCHECK_LE(params.quantized_activation_min,
2108                    params.quantized_activation_max);
2109 
2110   const int input1_shift = params.input1_shift;
2111   const int flat_size =
2112       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2113   const int16 output_activation_min = params.quantized_activation_min;
2114   const int16 output_activation_max = params.quantized_activation_max;
2115 
2116   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2117   TFLITE_DCHECK_LE(input1_shift, 0);
2118   TFLITE_DCHECK_LE(params.input2_shift, 0);
2119   const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2120   const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2121   const int input_right_shift =
2122       input1_shift == 0 ? -params.input2_shift : -input1_shift;
2123 
2124   for (int i = 0; i < flat_size; i++) {
2125     // F0 uses 0 integer bits, range [-1, 1].
2126     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2127 
2128     F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2129     F0 scaled_input = F0::FromRaw(
2130         gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2131     F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2132     const int16 raw_output = result.raw();
2133     const int16 clamped_output = std::min(
2134         output_activation_max, std::max(output_activation_min, raw_output));
2135     output_data[i] = clamped_output;
2136   }
2137 }
2138 
Add(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int32 * input1_data,const RuntimeShape & input2_shape,const int32 * input2_data,const RuntimeShape & output_shape,int32 * output_data)2139 inline void Add(const ArithmeticParams& params,
2140                 const RuntimeShape& input1_shape, const int32* input1_data,
2141                 const RuntimeShape& input2_shape, const int32* input2_data,
2142                 const RuntimeShape& output_shape, int32* output_data) {
2143   ruy::profiler::ScopeLabel label("Add/int32");
2144 
2145   auto input1_map = MapAsVector(input1_data, input1_shape);
2146   auto input2_map = MapAsVector(input2_data, input2_shape);
2147   auto output_map = MapAsVector(output_data, output_shape);
2148   if (input1_shape == input2_shape) {
2149     output_map.array() = (input1_map.array() + input2_map.array())
2150                              .cwiseMax(params.quantized_activation_min)
2151                              .cwiseMin(params.quantized_activation_max);
2152   } else if (input2_shape.FlatSize() == 1) {
2153     auto scalar = input2_data[0];
2154     output_map.array() = (input1_map.array() + scalar)
2155                              .cwiseMax(params.quantized_activation_min)
2156                              .cwiseMin(params.quantized_activation_max);
2157   } else if (input1_shape.FlatSize() == 1) {
2158     auto scalar = input1_data[0];
2159     output_map.array() = (scalar + input2_map.array())
2160                              .cwiseMax(params.quantized_activation_min)
2161                              .cwiseMin(params.quantized_activation_max);
2162   } else {
2163     reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
2164                                       input2_shape, input2_data, output_shape,
2165                                       output_data);
2166   }
2167 }
2168 
2169 template <typename T>
BroadcastAddDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,T * output_data)2170 inline void BroadcastAddDispatch(
2171     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2172     const T* input1_data, const RuntimeShape& input2_shape,
2173     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2174   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2175     return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
2176                               input2_data, output_shape, output_data);
2177   }
2178 
2179   BinaryBroadcastFiveFold(
2180       params, input1_shape, input1_data, input2_shape, input2_data,
2181       output_shape, output_data,
2182       static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2183                            T*)>(AddElementwise),
2184       static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2185           AddScalarBroadcast));
2186 }
2187 
BroadcastAddFivefold(const ArithmeticParams & unswitched_params,const RuntimeShape & unswitched_input1_shape,const uint8 * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const uint8 * unswitched_input2_data,const RuntimeShape & output_shape,uint8 * output_data)2188 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
2189                                  const RuntimeShape& unswitched_input1_shape,
2190                                  const uint8* unswitched_input1_data,
2191                                  const RuntimeShape& unswitched_input2_shape,
2192                                  const uint8* unswitched_input2_data,
2193                                  const RuntimeShape& output_shape,
2194                                  uint8* output_data) {
2195   BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
2196                        unswitched_input1_data, unswitched_input2_shape,
2197                        unswitched_input2_data, output_shape, output_data);
2198 }
2199 
BroadcastAddFivefold(const ArithmeticParams & params,const RuntimeShape & unswitched_input1_shape,const float * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const float * unswitched_input2_data,const RuntimeShape & output_shape,float * output_data)2200 inline void BroadcastAddFivefold(const ArithmeticParams& params,
2201                                  const RuntimeShape& unswitched_input1_shape,
2202                                  const float* unswitched_input1_data,
2203                                  const RuntimeShape& unswitched_input2_shape,
2204                                  const float* unswitched_input2_data,
2205                                  const RuntimeShape& output_shape,
2206                                  float* output_data) {
2207   BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2208                        unswitched_input2_shape, unswitched_input2_data,
2209                        output_shape, output_data);
2210 }
2211 
MulElementwise(int size,const ArithmeticParams & params,const float * input1_data,const float * input2_data,float * output_data)2212 inline void MulElementwise(int size, const ArithmeticParams& params,
2213                            const float* input1_data, const float* input2_data,
2214                            float* output_data) {
2215   const float output_activation_min = params.float_activation_min;
2216   const float output_activation_max = params.float_activation_max;
2217 
2218   int i = 0;
2219 #ifdef USE_NEON
2220   const auto activation_min = vdupq_n_f32(output_activation_min);
2221   const auto activation_max = vdupq_n_f32(output_activation_max);
2222   for (; i <= size - 16; i += 16) {
2223     auto a10 = vld1q_f32(input1_data + i);
2224     auto a11 = vld1q_f32(input1_data + i + 4);
2225     auto a12 = vld1q_f32(input1_data + i + 8);
2226     auto a13 = vld1q_f32(input1_data + i + 12);
2227     auto a20 = vld1q_f32(input2_data + i);
2228     auto a21 = vld1q_f32(input2_data + i + 4);
2229     auto a22 = vld1q_f32(input2_data + i + 8);
2230     auto a23 = vld1q_f32(input2_data + i + 12);
2231     auto x0 = vmulq_f32(a10, a20);
2232     auto x1 = vmulq_f32(a11, a21);
2233     auto x2 = vmulq_f32(a12, a22);
2234     auto x3 = vmulq_f32(a13, a23);
2235 
2236     x0 = vmaxq_f32(activation_min, x0);
2237     x1 = vmaxq_f32(activation_min, x1);
2238     x2 = vmaxq_f32(activation_min, x2);
2239     x3 = vmaxq_f32(activation_min, x3);
2240     x0 = vminq_f32(activation_max, x0);
2241     x1 = vminq_f32(activation_max, x1);
2242     x2 = vminq_f32(activation_max, x2);
2243     x3 = vminq_f32(activation_max, x3);
2244 
2245     vst1q_f32(output_data + i, x0);
2246     vst1q_f32(output_data + i + 4, x1);
2247     vst1q_f32(output_data + i + 8, x2);
2248     vst1q_f32(output_data + i + 12, x3);
2249   }
2250   for (; i <= size - 4; i += 4) {
2251     auto a1 = vld1q_f32(input1_data + i);
2252     auto a2 = vld1q_f32(input2_data + i);
2253     auto x = vmulq_f32(a1, a2);
2254 
2255     x = vmaxq_f32(activation_min, x);
2256     x = vminq_f32(activation_max, x);
2257 
2258     vst1q_f32(output_data + i, x);
2259   }
2260 #endif  // NEON
2261 
2262   for (; i < size; i++) {
2263     auto x = input1_data[i] * input2_data[i];
2264     output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
2265                                                   output_activation_max);
2266   }
2267 }
2268 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,float * output_data)2269 inline void Mul(const ArithmeticParams& params,
2270                 const RuntimeShape& input1_shape, const float* input1_data,
2271                 const RuntimeShape& input2_shape, const float* input2_data,
2272                 const RuntimeShape& output_shape, float* output_data) {
2273   ruy::profiler::ScopeLabel label("Mul");
2274 
2275   const int flat_size =
2276       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2277   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2278 }
2279 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int32 * input1_data,const RuntimeShape & input2_shape,const int32 * input2_data,const RuntimeShape & output_shape,int32 * output_data)2280 inline void Mul(const ArithmeticParams& params,
2281                 const RuntimeShape& input1_shape, const int32* input1_data,
2282                 const RuntimeShape& input2_shape, const int32* input2_data,
2283                 const RuntimeShape& output_shape, int32* output_data) {
2284   ruy::profiler::ScopeLabel label("Mul/int32/activation");
2285 
2286   const int flat_size =
2287       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2288   const int32 output_activation_min = params.quantized_activation_min;
2289   const int32 output_activation_max = params.quantized_activation_max;
2290   for (int i = 0; i < flat_size; ++i) {
2291     output_data[i] = ActivationFunctionWithMinMax(
2292         input1_data[i] * input2_data[i], output_activation_min,
2293         output_activation_max);
2294   }
2295 }
2296 
MulNoActivation(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int32 * input1_data,const RuntimeShape & input2_shape,const int32 * input2_data,const RuntimeShape & output_shape,int32 * output_data)2297 inline void MulNoActivation(const ArithmeticParams& params,
2298                             const RuntimeShape& input1_shape,
2299                             const int32* input1_data,
2300                             const RuntimeShape& input2_shape,
2301                             const int32* input2_data,
2302                             const RuntimeShape& output_shape,
2303                             int32* output_data) {
2304   ruy::profiler::ScopeLabel label("Mul/int32");
2305 
2306   auto input1_map = MapAsVector(input1_data, input1_shape);
2307   auto input2_map = MapAsVector(input2_data, input2_shape);
2308   auto output_map = MapAsVector(output_data, output_shape);
2309   if (input1_shape == input2_shape) {
2310     output_map.array() = input1_map.array() * input2_map.array();
2311   } else if (input2_shape.FlatSize() == 1) {
2312     auto scalar = input2_data[0];
2313     output_map.array() = input1_map.array() * scalar;
2314   } else if (input1_shape.FlatSize() == 1) {
2315     auto scalar = input1_data[0];
2316     output_map.array() = scalar * input2_map.array();
2317   } else {
2318     reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
2319                                       input2_shape, input2_data, output_shape,
2320                                       output_data);
2321   }
2322 }
2323 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16 * input1_data,const RuntimeShape & input2_shape,const int16 * input2_data,const RuntimeShape & output_shape,int16 * output_data)2324 inline void Mul(const ArithmeticParams& params,
2325                 const RuntimeShape& input1_shape, const int16* input1_data,
2326                 const RuntimeShape& input2_shape, const int16* input2_data,
2327                 const RuntimeShape& output_shape, int16* output_data) {
2328   ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
2329   // This is a copy of the reference implementation. We do not currently have a
2330   // properly optimized version.
2331 
2332   const int flat_size =
2333       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2334 
2335   for (int i = 0; i < flat_size; i++) {
2336     // F0 uses 0 integer bits, range [-1, 1].
2337     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2338 
2339     F0 unclamped_result =
2340         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2341     output_data[i] = unclamped_result.raw();
2342   }
2343 }
2344 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16 * input1_data,const RuntimeShape & input2_shape,const int16 * input2_data,const RuntimeShape & output_shape,uint8 * output_data)2345 inline void Mul(const ArithmeticParams& params,
2346                 const RuntimeShape& input1_shape, const int16* input1_data,
2347                 const RuntimeShape& input2_shape, const int16* input2_data,
2348                 const RuntimeShape& output_shape, uint8* output_data) {
2349   ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
2350   // This is a copy of the reference implementation. We do not currently have a
2351   // properly optimized version.
2352   const int32 output_activation_min = params.quantized_activation_min;
2353   const int32 output_activation_max = params.quantized_activation_max;
2354   const int32 output_offset = params.output_offset;
2355   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2356 
2357   const int flat_size =
2358       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2359 
2360   for (int i = 0; i < flat_size; i++) {
2361     // F0 uses 0 integer bits, range [-1, 1].
2362     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2363 
2364     F0 unclamped_result =
2365         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2366     int16 rescaled_result =
2367         gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
2368     int16 clamped_result =
2369         std::min<int16>(output_activation_max - output_offset, rescaled_result);
2370     clamped_result =
2371         std::max<int16>(output_activation_min - output_offset, clamped_result);
2372     output_data[i] = output_offset + clamped_result;
2373   }
2374 }
2375 
2376 // Element-wise mul that can often be used for inner loop of broadcast Mul as
2377 // well as the non-broadcast Mul.
MulElementwise(int size,const ArithmeticParams & params,const uint8 * input1_data,const uint8 * input2_data,uint8 * output_data)2378 inline void MulElementwise(int size, const ArithmeticParams& params,
2379                            const uint8* input1_data, const uint8* input2_data,
2380                            uint8* output_data) {
2381   int i = 0;
2382   TFLITE_DCHECK_GT(params.input1_offset, -256);
2383   TFLITE_DCHECK_LT(params.input1_offset, 256);
2384   TFLITE_DCHECK_GT(params.input2_offset, -256);
2385   TFLITE_DCHECK_LT(params.input2_offset, 256);
2386   TFLITE_DCHECK_GT(params.output_offset, -256);
2387   TFLITE_DCHECK_LT(params.output_offset, 256);
2388 #ifdef USE_NEON
2389   const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
2390   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2391   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2392   const auto output_activation_min_vector =
2393       vdup_n_u8(params.quantized_activation_min);
2394   const auto output_activation_max_vector =
2395       vdup_n_u8(params.quantized_activation_max);
2396   const int left_shift = std::max(0, params.output_shift);
2397   const int right_shift = std::max(0, -params.output_shift);
2398   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2399   for (; i <= size - 8; i += 8) {
2400     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2401     const auto input1_val_original = vld1_u8(input1_data + i);
2402     const auto input2_val_original = vld1_u8(input2_data + i);
2403     const auto input1_val_s16 =
2404         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2405     const auto input2_val_s16 =
2406         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2407     const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
2408     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2409 
2410     const auto input1_val_low = vget_low_s16(input1_val);
2411     const auto input1_val_high = vget_high_s16(input1_val);
2412     const auto input2_val_low = vget_low_s16(input2_val);
2413     const auto input2_val_high = vget_high_s16(input2_val);
2414 
2415     auto p1 = vmull_s16(input2_val_low, input1_val_low);
2416     auto p2 = vmull_s16(input2_val_high, input1_val_high);
2417 
2418     p1 = vshlq_s32(p1, left_shift_vec);
2419     p2 = vshlq_s32(p2, left_shift_vec);
2420     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2421     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2422     using gemmlowp::RoundingDivideByPOT;
2423     p1 = RoundingDivideByPOT(p1, right_shift);
2424     p2 = RoundingDivideByPOT(p2, right_shift);
2425 
2426     const auto p1_narrowed = vqmovn_s32(p1);
2427     const auto p2_narrowed = vqmovn_s32(p2);
2428     const auto p =
2429         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2430     const auto clamped =
2431         vmax_u8(output_activation_min_vector,
2432                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2433     vst1_u8(output_data + i, clamped);
2434   }
2435 #endif  // NEON
2436 
2437   for (; i < size; ++i) {
2438     const int32 input1_val = params.input1_offset + input1_data[i];
2439     const int32 input2_val = params.input2_offset + input2_data[i];
2440     const int32 unclamped_result =
2441         params.output_offset +
2442         MultiplyByQuantizedMultiplier(input1_val * input2_val,
2443                                       params.output_multiplier,
2444                                       params.output_shift);
2445     const int32 clamped_output =
2446         std::min(params.quantized_activation_max,
2447                  std::max(params.quantized_activation_min, unclamped_result));
2448     output_data[i] = static_cast<uint8>(clamped_output);
2449   }
2450 }
2451 
2452 // Broadcast mul that can often be used for inner loop of broadcast Mul.
MulSimpleBroadcast(int size,const ArithmeticParams & params,const uint8 broadcast_value,const uint8 * input2_data,uint8 * output_data)2453 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2454                                const uint8 broadcast_value,
2455                                const uint8* input2_data, uint8* output_data) {
2456   const int16 input1_val = params.input1_offset + broadcast_value;
2457 
2458   int i = 0;
2459   TFLITE_DCHECK_GT(params.input1_offset, -256);
2460   TFLITE_DCHECK_LT(params.input1_offset, 256);
2461   TFLITE_DCHECK_GT(params.input2_offset, -256);
2462   TFLITE_DCHECK_LT(params.input2_offset, 256);
2463   TFLITE_DCHECK_GT(params.output_offset, -256);
2464   TFLITE_DCHECK_LT(params.output_offset, 256);
2465 #ifdef USE_NEON
2466   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2467   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2468   const auto output_activation_min_vector =
2469       vdup_n_u8(params.quantized_activation_min);
2470   const auto output_activation_max_vector =
2471       vdup_n_u8(params.quantized_activation_max);
2472   const int left_shift = std::max(0, params.output_shift);
2473   const int right_shift = std::max(0, -params.output_shift);
2474   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2475   for (; i <= size - 8; i += 8) {
2476     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2477     const auto input2_val_original = vld1_u8(input2_data + i);
2478     const auto input2_val_s16 =
2479         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2480     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2481 
2482     const auto input2_val_low = vget_low_s16(input2_val);
2483     const auto input2_val_high = vget_high_s16(input2_val);
2484 
2485     auto p1 = vmull_n_s16(input2_val_low, input1_val);
2486     auto p2 = vmull_n_s16(input2_val_high, input1_val);
2487 
2488     p1 = vshlq_s32(p1, left_shift_vec);
2489     p2 = vshlq_s32(p2, left_shift_vec);
2490     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2491     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2492     using gemmlowp::RoundingDivideByPOT;
2493     p1 = RoundingDivideByPOT(p1, right_shift);
2494     p2 = RoundingDivideByPOT(p2, right_shift);
2495 
2496     const auto p1_narrowed = vmovn_s32(p1);
2497     const auto p2_narrowed = vmovn_s32(p2);
2498     const auto p =
2499         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2500     const auto clamped =
2501         vmax_u8(output_activation_min_vector,
2502                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2503     vst1_u8(output_data + i, clamped);
2504   }
2505 #endif  // NEON
2506 
2507   for (; i < size; ++i) {
2508     const int32 input2_val = params.input2_offset + input2_data[i];
2509     const int32 unclamped_result =
2510         params.output_offset +
2511         MultiplyByQuantizedMultiplier(input1_val * input2_val,
2512                                       params.output_multiplier,
2513                                       params.output_shift);
2514     const int32 clamped_output =
2515         std::min(params.quantized_activation_max,
2516                  std::max(params.quantized_activation_min, unclamped_result));
2517     output_data[i] = static_cast<uint8>(clamped_output);
2518   }
2519 }
2520 
2521 // Broadcast mul that can often be used for inner loop of broadcast Mul.
2522 // This function will handle scalar_value (LHS) * vector_values (RHS).
2523 // Since it's a float function, input params does not matter here.
MulSimpleBroadcast(int size,const ArithmeticParams & params,const float broadcast_value,const float * input2_data,float * output_data)2524 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2525                                const float broadcast_value,
2526                                const float* input2_data, float* output_data) {
2527   int i = 0;
2528 #ifdef USE_NEON
2529   const float32x4_t output_activation_min_vector =
2530       vdupq_n_f32(params.float_activation_min);
2531   const float32x4_t output_activation_max_vector =
2532       vdupq_n_f32(params.float_activation_max);
2533   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2534   for (; i <= size - 4; i += 4) {
2535     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2536 
2537     const float32x4_t output =
2538         vmulq_f32(input2_val_original, broadcast_value_dup);
2539 
2540     const float32x4_t clamped =
2541         vmaxq_f32(output_activation_min_vector,
2542                   vminq_f32(output_activation_max_vector, output));
2543     vst1q_f32(output_data + i, clamped);
2544   }
2545 #endif  // NEON
2546 
2547   for (; i < size; ++i) {
2548     float x = broadcast_value * input2_data[i];
2549     output_data[i] = ActivationFunctionWithMinMax(
2550         x, params.float_activation_min, params.float_activation_max);
2551   }
2552 }
2553 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const uint8 * input1_data,const RuntimeShape & input2_shape,const uint8 * input2_data,const RuntimeShape & output_shape,uint8 * output_data)2554 inline void Mul(const ArithmeticParams& params,
2555                 const RuntimeShape& input1_shape, const uint8* input1_data,
2556                 const RuntimeShape& input2_shape, const uint8* input2_data,
2557                 const RuntimeShape& output_shape, uint8* output_data) {
2558   TFLITE_DCHECK_LE(params.quantized_activation_min,
2559                    params.quantized_activation_max);
2560   ruy::profiler::ScopeLabel label("Mul/8bit");
2561   const int flat_size =
2562       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2563 
2564   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2565 }
2566 
2567 template <typename T>
BroadcastMulDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,T * output_data)2568 inline void BroadcastMulDispatch(
2569     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2570     const T* input1_data, const RuntimeShape& input2_shape,
2571     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2572   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2573     return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
2574                               input2_data, output_shape, output_data);
2575   }
2576 
2577   BinaryBroadcastFiveFold(
2578       params, input1_shape, input1_data, input2_shape, input2_data,
2579       output_shape, output_data,
2580       static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2581                            T*)>(MulElementwise),
2582       static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2583           MulSimpleBroadcast));
2584 }
2585 
BroadcastMulFivefold(const ArithmeticParams & unswitched_params,const RuntimeShape & unswitched_input1_shape,const uint8 * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const uint8 * unswitched_input2_data,const RuntimeShape & output_shape,uint8 * output_data)2586 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
2587                                  const RuntimeShape& unswitched_input1_shape,
2588                                  const uint8* unswitched_input1_data,
2589                                  const RuntimeShape& unswitched_input2_shape,
2590                                  const uint8* unswitched_input2_data,
2591                                  const RuntimeShape& output_shape,
2592                                  uint8* output_data) {
2593   BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
2594                        unswitched_input1_data, unswitched_input2_shape,
2595                        unswitched_input2_data, output_shape, output_data);
2596 }
2597 
BroadcastMulFivefold(const ArithmeticParams & params,const RuntimeShape & unswitched_input1_shape,const float * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const float * unswitched_input2_data,const RuntimeShape & output_shape,float * output_data)2598 inline void BroadcastMulFivefold(const ArithmeticParams& params,
2599                                  const RuntimeShape& unswitched_input1_shape,
2600                                  const float* unswitched_input1_data,
2601                                  const RuntimeShape& unswitched_input2_shape,
2602                                  const float* unswitched_input2_data,
2603                                  const RuntimeShape& output_shape,
2604                                  float* output_data) {
2605   BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2606                        unswitched_input2_shape, unswitched_input2_data,
2607                        output_shape, output_data);
2608 }
2609 
2610 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
2611 // dimensionality if the runtime code does a single loop over one dimension
2612 // that handles broadcasting as the base case. The code generator would then
2613 // generate max(D1, D2) nested for loops.
2614 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
2615 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
2616 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
2617 // reference_ops.h.
2618 template <typename T, int N = 5>
BroadcastDivSlow(const ArithmeticParams & params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,T * output_data)2619 void BroadcastDivSlow(const ArithmeticParams& params,
2620                       const RuntimeShape& unextended_input1_shape,
2621                       const T* input1_data,
2622                       const RuntimeShape& unextended_input2_shape,
2623                       const T* input2_data,
2624                       const RuntimeShape& unextended_output_shape,
2625                       T* output_data) {
2626   ruy::profiler::ScopeLabel label("BroadcastDivSlow");
2627   T output_activation_min;
2628   T output_activation_max;
2629   GetActivationParams(params, &output_activation_min, &output_activation_max);
2630 
2631   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2632   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2633   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2634 
2635   NdArrayDesc<N> desc1;
2636   NdArrayDesc<N> desc2;
2637   NdArrayDesc<N> output_desc;
2638   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2639                                       unextended_input2_shape, &desc1, &desc2);
2640   CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2641                  &output_desc);
2642 
2643   // In Tensorflow, the dimensions are canonically named (batch_number, row,
2644   // col, channel), with extents (batches, height, width, depth), with the
2645   // trailing dimension changing most rapidly (channels has the smallest stride,
2646   // typically 1 element).
2647   //
2648   // In generated C code, we store arrays with the dimensions reversed. The
2649   // first dimension has smallest stride.
2650   //
2651   // We name our variables by their Tensorflow convention, but generate C code
2652   // nesting loops such that the innermost loop has the smallest stride for the
2653   // best cache behavior.
2654   auto div_func = [&](int indexes[N]) {
2655     output_data[SubscriptToIndex(output_desc, indexes)] =
2656         ActivationFunctionWithMinMax(
2657             input1_data[SubscriptToIndex(desc1, indexes)] /
2658                 input2_data[SubscriptToIndex(desc2, indexes)],
2659             output_activation_min, output_activation_max);
2660   };
2661   NDOpsHelper<N>(output_desc, div_func);
2662 }
2663 
2664 // TODO: BroadcastDiv is intentionally duplicated from reference_ops.h.
2665 // For more details see the comment above the generic version of
2666 // BroadcastDivSlow.
2667 template <int N = 5>
BroadcastDivSlow(const ArithmeticParams & params,const RuntimeShape & unextended_input1_shape,const uint8 * input1_data,const RuntimeShape & unextended_input2_shape,const uint8 * input2_data,const RuntimeShape & unextended_output_shape,uint8 * output_data)2668 inline void BroadcastDivSlow(const ArithmeticParams& params,
2669                              const RuntimeShape& unextended_input1_shape,
2670                              const uint8* input1_data,
2671                              const RuntimeShape& unextended_input2_shape,
2672                              const uint8* input2_data,
2673                              const RuntimeShape& unextended_output_shape,
2674                              uint8* output_data) {
2675   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2676   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2677   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2678 
2679   NdArrayDesc<N> desc1;
2680   NdArrayDesc<N> desc2;
2681   NdArrayDesc<N> output_desc;
2682   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2683                                       unextended_input2_shape, &desc1, &desc2);
2684   CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2685                  &output_desc);
2686 
2687   TFLITE_DCHECK_GT(params.input1_offset, -256);
2688   TFLITE_DCHECK_LT(params.input1_offset, 256);
2689   TFLITE_DCHECK_GT(params.input2_offset, -256);
2690   TFLITE_DCHECK_LT(params.input2_offset, 256);
2691   TFLITE_DCHECK_GT(params.output_offset, -256);
2692   TFLITE_DCHECK_LT(params.output_offset, 256);
2693 
2694   auto div_func = [&](int indexes[N]) {
2695     const int32 input1_val =
2696         params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
2697     const int32 input2_val =
2698         params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
2699     TFLITE_DCHECK_NE(input2_val, 0);
2700     int recip_shift;
2701     const int32 input2_inv =
2702         (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
2703                          : -GetReciprocal(-input2_val, 31, &recip_shift);
2704     const int headroom = CountLeadingSignBits(input1_val);
2705     const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
2706         input1_val, input2_inv, headroom);
2707     const int total_shift = params.output_shift - recip_shift - headroom;
2708     const int32 unclamped_result =
2709         params.output_offset +
2710         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2711             unscaled_quotient, params.output_multiplier, total_shift);
2712     const int32 clamped_output =
2713         std::min(params.quantized_activation_max,
2714                  std::max(params.quantized_activation_min, unclamped_result));
2715     output_data[SubscriptToIndex(output_desc, indexes)] =
2716         static_cast<uint8>(clamped_output);
2717   };
2718   NDOpsHelper<N>(output_desc, div_func);
2719 }
2720 
2721 template <typename T>
SubWithActivation(const ArithmeticParams & params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,T * output_data)2722 inline void SubWithActivation(
2723     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2724     const T* input1_data, const RuntimeShape& input2_shape,
2725     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2726   ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
2727   TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
2728   auto input1_map = MapAsVector(input1_data, input1_shape);
2729   auto input2_map = MapAsVector(input2_data, input2_shape);
2730   auto output_map = MapAsVector(output_data, output_shape);
2731   T activation_min, activation_max;
2732   GetActivationParams(params, &activation_min, &activation_max);
2733   output_map.array() = (input1_map.array() - input2_map.array())
2734                            .cwiseMin(activation_max)
2735                            .cwiseMax(activation_min);
2736 }
2737 
SubNonBroadcast(const ArithmeticParams & params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,float * output_data)2738 inline void SubNonBroadcast(const ArithmeticParams& params,
2739                             const RuntimeShape& input1_shape,
2740                             const float* input1_data,
2741                             const RuntimeShape& input2_shape,
2742                             const float* input2_data,
2743                             const RuntimeShape& output_shape,
2744                             float* output_data) {
2745   ruy::profiler::ScopeLabel label("SubNonBroadcast");
2746   SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
2747                            input2_data, output_shape, output_data);
2748 }
2749 
2750 template <typename T>
Sub(const ArithmeticParams & params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,T * output_data)2751 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
2752          const T* input1_data, const RuntimeShape& input2_shape,
2753          const T* input2_data, const RuntimeShape& output_shape,
2754          T* output_data) {
2755   ruy::profiler::ScopeLabel label("Sub");
2756 
2757   auto input1_map = MapAsVector(input1_data, input1_shape);
2758   auto input2_map = MapAsVector(input2_data, input2_shape);
2759   auto output_map = MapAsVector(output_data, output_shape);
2760   if (input1_shape == input2_shape) {
2761     output_map.array() = input1_map.array() - input2_map.array();
2762   } else if (input1_shape.FlatSize() == 1) {
2763     auto scalar = input1_data[0];
2764     output_map.array() = scalar - input2_map.array();
2765   } else if (input2_shape.FlatSize() == 1) {
2766     auto scalar = input2_data[0];
2767     output_map.array() = input1_map.array() - scalar;
2768   } else {
2769     BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
2770                      input2_data, output_shape, output_data);
2771   }
2772 }
2773 
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const float * input_data,const RuntimeShape & unextended_prev_activ_shape,const float * prev_activ_data,const RuntimeShape & weights_shape,const float * weights_data,const RuntimeShape & unextended_bias_shape,const float * bias_data,const RuntimeShape & unextended_prev_state_shape,const float * prev_state_data,const RuntimeShape & unextended_output_state_shape,float * output_state_data,const RuntimeShape & unextended_output_activ_shape,float * output_activ_data,const RuntimeShape & unextended_concat_temp_shape,float * concat_temp_data,const RuntimeShape & unextended_activ_temp_shape,float * activ_temp_data,CpuBackendContext * cpu_backend_context)2774 inline void LstmCell(
2775     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2776     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2777     const float* prev_activ_data, const RuntimeShape& weights_shape,
2778     const float* weights_data, const RuntimeShape& unextended_bias_shape,
2779     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2780     const float* prev_state_data,
2781     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2782     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2783     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2784     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
2785     CpuBackendContext* cpu_backend_context) {
2786   ruy::profiler::ScopeLabel label("LstmCell");
2787   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2788   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2789   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2790   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2791   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2792   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2793   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2794   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2795   const RuntimeShape input_shape =
2796       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2797   const RuntimeShape prev_activ_shape =
2798       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2799   const RuntimeShape bias_shape =
2800       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2801   const RuntimeShape prev_state_shape =
2802       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2803   const RuntimeShape output_state_shape =
2804       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2805   const RuntimeShape output_activ_shape =
2806       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2807   const RuntimeShape concat_temp_shape =
2808       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2809   const RuntimeShape activ_temp_shape =
2810       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2811   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2812 
2813   const int weights_dim_count = weights_shape.DimensionsCount();
2814   MatchingDim(  // batches
2815       input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
2816       output_state_shape, 0, output_activ_shape, 0);
2817   MatchingDim(  // height
2818       input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
2819       output_state_shape, 1, output_activ_shape, 1);
2820   MatchingDim(  // width
2821       input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
2822       output_state_shape, 2, output_activ_shape, 2);
2823   const int input_depth = input_shape.Dims(3);
2824   const int prev_activ_depth = prev_activ_shape.Dims(3);
2825   const int total_input_depth = prev_activ_depth + input_depth;
2826   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2827                    total_input_depth);
2828   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2829   const int intern_activ_depth =
2830       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2831   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2832                    intern_activ_depth * total_input_depth);
2833   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2834   const int output_depth =
2835       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2836                   3, output_activ_shape, 3);
2837   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2838 
2839   // Concatenate prev_activ and input data together
2840   std::vector<float const*> concat_input_arrays_data;
2841   std::vector<RuntimeShape const*> concat_input_arrays_shapes;
2842   concat_input_arrays_data.push_back(input_data);
2843   concat_input_arrays_data.push_back(prev_activ_data);
2844   concat_input_arrays_shapes.push_back(&input_shape);
2845   concat_input_arrays_shapes.push_back(&prev_activ_shape);
2846   tflite::ConcatenationParams concat_params;
2847   concat_params.axis = 3;
2848   concat_params.inputs_count = concat_input_arrays_data.size();
2849   Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
2850                 &(concat_input_arrays_data[0]), concat_temp_shape,
2851                 concat_temp_data);
2852 
2853   // Fully connected
2854   tflite::FullyConnectedParams fc_params;
2855   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
2856   fc_params.float_activation_max = std::numeric_limits<float>::max();
2857   fc_params.lhs_cacheable = false;
2858   fc_params.rhs_cacheable = false;
2859   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
2860                  weights_data, bias_shape, bias_data, activ_temp_shape,
2861                  activ_temp_data, cpu_backend_context);
2862 
2863   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
2864   // operations.
2865   ArrayMap<float> activ_temp_map =
2866       MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
2867   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
2868                                             activ_temp_map.cols());
2869   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
2870                                            activ_temp_map.cols());
2871   auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
2872                                              activ_temp_map.cols());
2873   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
2874                                              activ_temp_map.cols());
2875   ArrayMap<const float> prev_state_map =
2876       MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
2877   ArrayMap<float> output_state_map =
2878       MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
2879   ArrayMap<float> output_activ_map =
2880       MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
2881 
2882   // Combined memory state and final output calculation
2883   ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
2884   output_state_map =
2885       input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2886           new_input_sm.tanh() +
2887       forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2888           prev_state_map;
2889   output_activ_map =
2890       output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2891       output_state_map.tanh();
2892 }
2893 
2894 template <int StateIntegerBits>
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const uint8 * input_data_uint8,const RuntimeShape & unextended_prev_activ_shape,const uint8 * prev_activ_data_uint8,const RuntimeShape & weights_shape,const uint8 * weights_data_uint8,const RuntimeShape & unextended_bias_shape,const int32 * bias_data_int32,const RuntimeShape & unextended_prev_state_shape,const int16 * prev_state_data_int16,const RuntimeShape & unextended_output_state_shape,int16 * output_state_data_int16,const RuntimeShape & unextended_output_activ_shape,uint8 * output_activ_data_uint8,const RuntimeShape & unextended_concat_temp_shape,uint8 * concat_temp_data_uint8,const RuntimeShape & unextended_activ_temp_shape,int16 * activ_temp_data_int16,CpuBackendContext * cpu_backend_context)2895 inline void LstmCell(
2896     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2897     const uint8* input_data_uint8,
2898     const RuntimeShape& unextended_prev_activ_shape,
2899     const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2900     const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2901     const int32* bias_data_int32,
2902     const RuntimeShape& unextended_prev_state_shape,
2903     const int16* prev_state_data_int16,
2904     const RuntimeShape& unextended_output_state_shape,
2905     int16* output_state_data_int16,
2906     const RuntimeShape& unextended_output_activ_shape,
2907     uint8* output_activ_data_uint8,
2908     const RuntimeShape& unextended_concat_temp_shape,
2909     uint8* concat_temp_data_uint8,
2910     const RuntimeShape& unextended_activ_temp_shape,
2911     int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2912   ruy::profiler::ScopeLabel label(
2913       "LstmCell/quantized (8bit external, 16bit internal)");
2914   int32 weights_zero_point = params.weights_zero_point;
2915   int32 accum_multiplier = params.accum_multiplier;
2916   int accum_shift = params.accum_shift;
2917   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2918   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2919   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2920   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2921   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2922   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2923   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2924   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2925   const RuntimeShape input_shape =
2926       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2927   const RuntimeShape prev_activ_shape =
2928       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2929   const RuntimeShape bias_shape =
2930       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2931   const RuntimeShape prev_state_shape =
2932       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2933   const RuntimeShape output_state_shape =
2934       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2935   const RuntimeShape output_activ_shape =
2936       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2937   const RuntimeShape concat_temp_shape =
2938       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2939   const RuntimeShape activ_temp_shape =
2940       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2941   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2942 
2943   // Gather dimensions information, and perform consistency checks.
2944   const int weights_dim_count = weights_shape.DimensionsCount();
2945   const int outer_size = MatchingFlatSizeSkipDim(
2946       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2947       output_activ_shape);
2948   const int input_depth = input_shape.Dims(3);
2949   const int prev_activ_depth = prev_activ_shape.Dims(3);
2950   const int total_input_depth = prev_activ_depth + input_depth;
2951   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2952                    total_input_depth);
2953   const int intern_activ_depth =
2954       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2955   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2956                    intern_activ_depth * total_input_depth);
2957   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2958   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2959   const int output_depth =
2960       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2961                   3, output_activ_shape, 3);
2962   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2963   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2964   const int fc_output_depth =
2965       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2966   const int fc_accum_depth = total_input_depth;
2967   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2968 
2969   // Depth-concatenate prev_activ and input data together.
2970   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2971                                               prev_activ_data_uint8};
2972   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2973                                                        &prev_activ_shape};
2974   tflite::ConcatenationParams concat_params;
2975   concat_params.axis = 3;
2976   concat_params.inputs_count = 2;
2977   Concatenation(concat_params, concat_input_arrays_shapes,
2978                 concat_input_arrays_data, concat_temp_shape,
2979                 concat_temp_data_uint8);
2980 
2981   // Implementation of the fully connected node inside the LSTM cell.
2982   // The operands are 8-bit integers, the accumulators are internally 32bit
2983   // integers, and the output is 16-bit fixed-point with 3 integer bits so
2984   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2985   // is explained in the function comment above.
2986   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2987   lhs_params.rows = fc_output_depth;
2988   lhs_params.cols = fc_accum_depth;
2989   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2990   lhs_params.zero_point = weights_zero_point;
2991   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2992   rhs_params.rows = fc_accum_depth;
2993   rhs_params.cols = fc_batches;
2994   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2995   rhs_params.zero_point = 128;
2996   cpu_backend_gemm::MatrixParams<int16> dst_params;
2997   dst_params.rows = fc_output_depth;
2998   dst_params.cols = fc_batches;
2999   dst_params.order = cpu_backend_gemm::Order::kColMajor;
3000   dst_params.zero_point = 0;
3001   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
3002   gemm_params.bias = bias_data_int32;
3003   gemm_params.multiplier_fixedpoint = accum_multiplier;
3004   gemm_params.multiplier_exponent = accum_shift;
3005   cpu_backend_gemm::Gemm(
3006       lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
3007       dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
3008 
3009   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3010   // and muls, all done in 16-bit fixed-point.
3011   const int16* input_gate_input_ptr = activ_temp_data_int16;
3012   const int16* input_modulation_gate_input_ptr =
3013       activ_temp_data_int16 + output_depth;
3014   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3015   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3016   const int16* prev_state_ptr = prev_state_data_int16;
3017   int16* output_state_data_ptr = output_state_data_int16;
3018   uint8* output_activ_data_ptr = output_activ_data_uint8;
3019 
3020   for (int b = 0; b < outer_size; ++b) {
3021     int c = 0;
3022 #ifdef GEMMLOWP_NEON
3023     for (; c <= output_depth - 8; c += 8) {
3024       // Define the fixed-point data types that we will use here. All use
3025       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3026       // They only differ by the number of integral vs. fractional bits,
3027       // determining the range of values that they can represent.
3028       //
3029       // F0 uses 0 integer bits, range [-1, 1].
3030       // This is the return type of math functions such as tanh, logistic,
3031       // whose range is in [-1, 1].
3032       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3033       // F3 uses 3 integer bits, range [-8, 8].
3034       // This is the range of the previous fully-connected node's output,
3035       // which is our input here.
3036       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3037       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3038       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3039       // number of integer bits is currently dictated by the model. See comment
3040       // on the StateIntegerBits template parameter above.
3041       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3042       // Implementation of input gate, using fixed-point logistic function.
3043       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3044       input_gate_input_ptr += 8;
3045       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3046       // Implementation of input modulation gate, using fixed-point tanh
3047       // function.
3048       F3 input_modulation_gate_input =
3049           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3050       input_modulation_gate_input_ptr += 8;
3051       F0 input_modulation_gate_output =
3052           gemmlowp::tanh(input_modulation_gate_input);
3053       // Implementation of forget gate, using fixed-point logistic function.
3054       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3055       forget_gate_input_ptr += 8;
3056       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3057       // Implementation of output gate, using fixed-point logistic function.
3058       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3059       output_gate_input_ptr += 8;
3060       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3061       // Implementation of internal multiplication nodes, still in fixed-point.
3062       F0 input_times_input_modulation =
3063           input_gate_output * input_modulation_gate_output;
3064       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3065       prev_state_ptr += 8;
3066       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3067       // Implementation of internal addition node, saturating.
3068       FS new_state = gemmlowp::SaturatingAdd(
3069           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3070           prev_state_times_forget_state);
3071       // Implementation of last internal Tanh node, still in fixed-point.
3072       // Since a Tanh fixed-point implementation is specialized for a given
3073       // number or integer bits, and each specialization can have a substantial
3074       // code size, and we already used above a Tanh on an input with 3 integer
3075       // bits, and per the table in the above function comment there is no
3076       // significant accuracy to be lost by clamping to [-8, +8] for a
3077       // 3-integer-bits representation, let us just do that. This helps people
3078       // porting this to targets where code footprint must be minimized.
3079       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3080       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3081       // Store the new internal state back to memory, as 16-bit integers.
3082       // Note: here we store the original value with StateIntegerBits, not
3083       // the rescaled 3-integer-bits value fed to tanh.
3084       vst1q_s16(output_state_data_ptr, new_state.raw());
3085       output_state_data_ptr += 8;
3086       // Down-scale the output activations to 8-bit integers, saturating,
3087       // and store back to memory.
3088       int16x8_t rescaled_output_activ =
3089           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3090       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3091       uint8x8_t uint8_output_activ =
3092           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3093       vst1_u8(output_activ_data_ptr, uint8_output_activ);
3094       output_activ_data_ptr += 8;
3095     }
3096 #endif
3097     for (; c < output_depth; ++c) {
3098       // Define the fixed-point data types that we will use here. All use
3099       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3100       // They only differ by the number of integral vs. fractional bits,
3101       // determining the range of values that they can represent.
3102       //
3103       // F0 uses 0 integer bits, range [-1, 1].
3104       // This is the return type of math functions such as tanh, logistic,
3105       // whose range is in [-1, 1].
3106       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3107       // F3 uses 3 integer bits, range [-8, 8].
3108       // This is the range of the previous fully-connected node's output,
3109       // which is our input here.
3110       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3111       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3112       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3113       // number of integer bits is currently dictated by the model. See comment
3114       // on the StateIntegerBits template parameter above.
3115       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3116       // Implementation of input gate, using fixed-point logistic function.
3117       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3118       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3119       // Implementation of input modulation gate, using fixed-point tanh
3120       // function.
3121       F3 input_modulation_gate_input =
3122           F3::FromRaw(*input_modulation_gate_input_ptr++);
3123       F0 input_modulation_gate_output =
3124           gemmlowp::tanh(input_modulation_gate_input);
3125       // Implementation of forget gate, using fixed-point logistic function.
3126       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3127       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3128       // Implementation of output gate, using fixed-point logistic function.
3129       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3130       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3131       // Implementation of internal multiplication nodes, still in fixed-point.
3132       F0 input_times_input_modulation =
3133           input_gate_output * input_modulation_gate_output;
3134       FS prev_state = FS::FromRaw(*prev_state_ptr++);
3135       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3136       // Implementation of internal addition node, saturating.
3137       FS new_state = gemmlowp::SaturatingAdd(
3138           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3139           prev_state_times_forget_state);
3140       // Implementation of last internal Tanh node, still in fixed-point.
3141       // Since a Tanh fixed-point implementation is specialized for a given
3142       // number or integer bits, and each specialization can have a substantial
3143       // code size, and we already used above a Tanh on an input with 3 integer
3144       // bits, and per the table in the above function comment there is no
3145       // significant accuracy to be lost by clamping to [-8, +8] for a
3146       // 3-integer-bits representation, let us just do that. This helps people
3147       // porting this to targets where code footprint must be minimized.
3148       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3149       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3150       // Store the new internal state back to memory, as 16-bit integers.
3151       // Note: here we store the original value with StateIntegerBits, not
3152       // the rescaled 3-integer-bits value fed to tanh.
3153       *output_state_data_ptr++ = new_state.raw();
3154       // Down-scale the output activations to 8-bit integers, saturating,
3155       // and store back to memory.
3156       int16 rescaled_output_activ =
3157           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3158       int16 clamped_output_activ =
3159           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3160       *output_activ_data_ptr++ = 128 + clamped_output_activ;
3161     }
3162     input_gate_input_ptr += 3 * output_depth;
3163     input_modulation_gate_input_ptr += 3 * output_depth;
3164     forget_gate_input_ptr += 3 * output_depth;
3165     output_gate_input_ptr += 3 * output_depth;
3166   }
3167 }
3168 
// Linear index of position (b, h, w) in a row-major [batch, height, width]
// grid: batches are outermost, then rows, then columns.
inline int NodeOffset(int b, int h, int w, int height, int width) {
  const int row_index = b * height + h;
  return row_index * width + w;
}
3172 
AveragePool(const PoolParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)3173 inline void AveragePool(const PoolParams& params,
3174                         const RuntimeShape& input_shape,
3175                         const float* input_data,
3176                         const RuntimeShape& output_shape, float* output_data) {
3177   ruy::profiler::ScopeLabel label("AveragePool");
3178   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3179   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3180   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3181   const int input_height = input_shape.Dims(1);
3182   const int input_width = input_shape.Dims(2);
3183   const int output_height = output_shape.Dims(1);
3184   const int output_width = output_shape.Dims(2);
3185   const int stride_height = params.stride_height;
3186   const int stride_width = params.stride_width;
3187 
3188   // TODO(benoitjacob) make this a proper reference impl without Eigen!
3189   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3190   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3191   // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3192   Eigen::VectorXf out_count(out_mat.cols());
3193   out_count.setZero();
3194   // Prefill the output to 0.
3195   out_mat.setZero();
3196   for (int b = 0; b < batches; ++b) {
3197     for (int h = 0; h < input_height; ++h) {
3198       for (int w = 0; w < input_width; ++w) {
3199         // (h_start, h_end) * (w_start, w_end) is the range that the input
3200         // vector projects to.
3201         int hpad = h + params.padding_values.height;
3202         int wpad = w + params.padding_values.width;
3203         int h_start = (hpad < params.filter_height)
3204                           ? 0
3205                           : (hpad - params.filter_height) / stride_height + 1;
3206         int h_end = std::min(hpad / stride_height + 1, output_height);
3207         int w_start = (wpad < params.filter_width)
3208                           ? 0
3209                           : (wpad - params.filter_width) / stride_width + 1;
3210         int w_end = std::min(wpad / stride_width + 1, output_width);
3211         // compute elementwise sum
3212         for (int ph = h_start; ph < h_end; ++ph) {
3213           for (int pw = w_start; pw < w_end; ++pw) {
3214             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3215             out_mat.col(out_offset) +=
3216                 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
3217             out_count(out_offset)++;
3218           }
3219         }
3220       }
3221     }
3222   }
3223   // Divide the output by the actual number of elements being averaged over
3224   TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
3225   out_mat.array().rowwise() /= out_count.transpose().array();
3226 
3227   const int flat_size = output_shape.FlatSize();
3228   for (int i = 0; i < flat_size; ++i) {
3229     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3230                                                   params.float_activation_min,
3231                                                   params.float_activation_max);
3232   }
3233 }
3234 
AveragePool(const PoolParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)3235 inline void AveragePool(const PoolParams& params,
3236                         const RuntimeShape& input_shape,
3237                         const uint8* input_data,
3238                         const RuntimeShape& output_shape, uint8* output_data) {
3239   ruy::profiler::ScopeLabel label("AveragePool/8bit");
3240 
3241   // Here, and in other pooling ops, in order to maintain locality of reference,
3242   // to minimize some recalculations, and to load into NEON vector registers, we
3243   // use an inner loop down the depth. Since depths can be large and hence we
3244   // would need arbitrarily large temporary storage, we divide the work up into
3245   // depth tranches just within the batch loop.
3246   static constexpr int kPoolingAccTrancheSize = 256;
3247 
3248   TFLITE_DCHECK_LE(params.quantized_activation_min,
3249                    params.quantized_activation_max);
3250   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3251   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3252   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3253   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3254   const int input_height = input_shape.Dims(1);
3255   const int input_width = input_shape.Dims(2);
3256   const int output_height = output_shape.Dims(1);
3257   const int output_width = output_shape.Dims(2);
3258   const int stride_height = params.stride_height;
3259   const int stride_width = params.stride_width;
3260 
3261   uint32 acc[kPoolingAccTrancheSize];
3262   for (int batch = 0; batch < batches; ++batch) {
3263     // We proceed through the depth in tranches (see comment above). The
3264     // depth_base is the depth at the beginning of the tranche. The
3265     // tranche_depth is the depth dimension of the tranche.
3266     for (int depth_base = 0; depth_base < depth;
3267          depth_base += kPoolingAccTrancheSize) {
3268       const int tranche_depth =
3269           std::min(depth - depth_base, kPoolingAccTrancheSize);
3270       for (int out_y = 0; out_y < output_height; ++out_y) {
3271         for (int out_x = 0; out_x < output_width; ++out_x) {
3272           const int in_x_origin =
3273               (out_x * stride_width) - params.padding_values.width;
3274           const int in_y_origin =
3275               (out_y * stride_height) - params.padding_values.height;
3276           const int filter_x_start = std::max(0, -in_x_origin);
3277           const int filter_x_end =
3278               std::min(params.filter_width, input_width - in_x_origin);
3279           const int filter_y_start = std::max(0, -in_y_origin);
3280           const int filter_y_end =
3281               std::min(params.filter_height, input_height - in_y_origin);
3282           const int filter_count =
3283               (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
3284           memset(acc, 0, tranche_depth * sizeof(acc[0]));
3285           const uint8* input_ptr =
3286               input_data + depth_base +
3287               depth * (in_x_origin +
3288                        input_width * (in_y_origin + input_height * batch));
3289           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3290             const uint8* input_row_ptr =
3291                 input_ptr + depth * (fy * input_width + filter_x_start);
3292             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3293               const uint8* input_channel_ptr = input_row_ptr;
3294               int channel = 0;
3295 #ifdef USE_NEON
3296               for (; channel <= tranche_depth - 16; channel += 16) {
3297                 uint16x4_t acc_reg[4];
3298                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3299                 input_channel_ptr += 16;
3300                 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
3301                 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
3302                 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
3303                 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
3304                 for (int i = 0; i < 4; i++) {
3305                   vst1q_u32(
3306                       acc + channel + 4 * i,
3307                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3308                 }
3309               }
3310               for (; channel <= tranche_depth - 8; channel += 8) {
3311                 uint16x4_t acc_reg[2];
3312                 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
3313                 input_channel_ptr += 8;
3314                 acc_reg[0] = vget_low_u16(input_reg);
3315                 acc_reg[1] = vget_high_u16(input_reg);
3316                 for (int i = 0; i < 2; i++) {
3317                   vst1q_u32(
3318                       acc + channel + 4 * i,
3319                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3320                 }
3321               }
3322 #endif
3323               for (; channel < tranche_depth; ++channel) {
3324                 acc[channel] += *input_channel_ptr++;
3325               }
3326               input_row_ptr += depth;
3327             }
3328           }
3329           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3330                                                    out_x, depth_base);
3331           int channel = 0;
3332 #ifdef USE_NEON
3333 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                               \
3334   if (filter_count == FILTER_COUNT) {                                   \
3335     for (; channel <= tranche_depth - 8; channel += 8) {                \
3336       uint16 buf[8];                                                    \
3337       for (int i = 0; i < 8; i++) {                                     \
3338         buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT;  \
3339       }                                                                 \
3340       uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));                      \
3341       buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3342       buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3343       vst1_u8(output_ptr + channel, buf8);                              \
3344     }                                                                   \
3345   }
3346           AVGPOOL_DIVIDING_BY(9)
3347           AVGPOOL_DIVIDING_BY(15)
3348 #undef AVGPOOL_DIVIDING_BY
3349           for (; channel <= tranche_depth - 8; channel += 8) {
3350             uint16 buf[8];
3351             for (int i = 0; i < 8; i++) {
3352               buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3353             }
3354             uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3355             buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3356             buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3357             vst1_u8(output_ptr + channel, buf8);
3358           }
3359 #endif
3360           for (; channel < tranche_depth; ++channel) {
3361             uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3362             a = std::max<uint16>(a, params.quantized_activation_min);
3363             a = std::min<uint16>(a, params.quantized_activation_max);
3364             output_ptr[channel] = static_cast<uint8>(a);
3365           }
3366         }
3367       }
3368     }
3369   }
3370 }
3371 
MaxPool(const PoolParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)3372 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3373                     const float* input_data, const RuntimeShape& output_shape,
3374                     float* output_data) {
3375   ruy::profiler::ScopeLabel label("MaxPool");
3376   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3377   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3378   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3379   const int input_height = input_shape.Dims(1);
3380   const int input_width = input_shape.Dims(2);
3381   const int output_height = output_shape.Dims(1);
3382   const int output_width = output_shape.Dims(2);
3383   const int stride_height = params.stride_height;
3384   const int stride_width = params.stride_width;
3385 
3386   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3387   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3388   // Prefill the output to minimum representable float value
3389   out_mat.setConstant(std::numeric_limits<float>::lowest());
3390   for (int b = 0; b < batches; ++b) {
3391     for (int h = 0; h < input_height; ++h) {
3392       for (int w = 0; w < input_width; ++w) {
3393         // (h_start, h_end) * (w_start, w_end) is the range that the input
3394         // vector projects to.
3395         int hpad = h + params.padding_values.height;
3396         int wpad = w + params.padding_values.width;
3397         int h_start = (hpad < params.filter_height)
3398                           ? 0
3399                           : (hpad - params.filter_height) / stride_height + 1;
3400         int h_end = std::min(hpad / stride_height + 1, output_height);
3401         int w_start = (wpad < params.filter_width)
3402                           ? 0
3403                           : (wpad - params.filter_width) / stride_width + 1;
3404         int w_end = std::min(wpad / stride_width + 1, output_width);
3405         // compute elementwise sum
3406         for (int ph = h_start; ph < h_end; ++ph) {
3407           for (int pw = w_start; pw < w_end; ++pw) {
3408             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3409             out_mat.col(out_offset) =
3410                 out_mat.col(out_offset)
3411                     .cwiseMax(in_mat.col(
3412                         NodeOffset(b, h, w, input_height, input_width)));
3413           }
3414         }
3415       }
3416     }
3417   }
3418   const int flat_size = output_shape.FlatSize();
3419   for (int i = 0; i < flat_size; ++i) {
3420     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3421                                                   params.float_activation_min,
3422                                                   params.float_activation_max);
3423   }
3424 }
3425 
MaxPool(const PoolParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)3426 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3427                     const uint8* input_data, const RuntimeShape& output_shape,
3428                     uint8* output_data) {
3429   ruy::profiler::ScopeLabel label("MaxPool/8bit");
3430 
3431   // Here, and in other pooling ops, in order to maintain locality of reference,
3432   // to minimize some recalculations, and to load into NEON vector registers, we
3433   // use an inner loop down the depth. Since depths can be large and hence we
3434   // would need arbitrarily large temporary storage, we divide the work up into
3435   // depth tranches just within the batch loop.
3436   static constexpr int kPoolingAccTrancheSize = 256;
3437 
3438   TFLITE_DCHECK_LE(params.quantized_activation_min,
3439                    params.quantized_activation_max);
3440   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3441   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3442   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3443   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3444   const int input_height = input_shape.Dims(1);
3445   const int input_width = input_shape.Dims(2);
3446   const int output_height = output_shape.Dims(1);
3447   const int output_width = output_shape.Dims(2);
3448   const int stride_height = params.stride_height;
3449   const int stride_width = params.stride_width;
3450 
3451   uint8 acc[kPoolingAccTrancheSize];
3452   for (int batch = 0; batch < batches; ++batch) {
3453     // We proceed through the depth in tranches (see comment above). The
3454     // depth_base is the depth at the beginning of the tranche. The
3455     // tranche_depth is the depth dimension of the tranche.
3456     for (int depth_base = 0; depth_base < depth;
3457          depth_base += kPoolingAccTrancheSize) {
3458       const int tranche_depth =
3459           std::min(depth - depth_base, kPoolingAccTrancheSize);
3460       for (int out_y = 0; out_y < output_height; ++out_y) {
3461         for (int out_x = 0; out_x < output_width; ++out_x) {
3462           const int in_x_origin =
3463               (out_x * stride_width) - params.padding_values.width;
3464           const int in_y_origin =
3465               (out_y * stride_height) - params.padding_values.height;
3466           const int filter_x_start = std::max(0, -in_x_origin);
3467           const int filter_x_end =
3468               std::min(params.filter_width, input_width - in_x_origin);
3469           const int filter_y_start = std::max(0, -in_y_origin);
3470           const int filter_y_end =
3471               std::min(params.filter_height, input_height - in_y_origin);
3472           memset(acc, 0, tranche_depth * sizeof(acc[0]));
3473           const uint8* input_ptr =
3474               input_data + depth_base +
3475               depth * (in_x_origin +
3476                        input_width * (in_y_origin + input_height * batch));
3477           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3478             const uint8* input_row_ptr =
3479                 input_ptr + depth * (fy * input_width + filter_x_start);
3480             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3481               const uint8* input_channel_ptr = input_row_ptr;
3482               int channel = 0;
3483 #ifdef USE_NEON
3484               for (; channel <= tranche_depth - 16; channel += 16) {
3485                 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3486                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3487                 input_channel_ptr += 16;
3488                 acc_reg = vmaxq_u8(acc_reg, input_reg);
3489                 vst1q_u8(acc + channel, acc_reg);
3490               }
3491 
3492               for (; channel <= tranche_depth - 8; channel += 8) {
3493                 uint8x8_t acc_reg = vld1_u8(acc + channel);
3494                 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3495                 input_channel_ptr += 8;
3496                 acc_reg = vmax_u8(acc_reg, input_reg);
3497                 vst1_u8(acc + channel, acc_reg);
3498               }
3499 #endif
3500               for (; channel < tranche_depth; ++channel) {
3501                 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3502               }
3503               input_row_ptr += depth;
3504             }
3505           }
3506           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3507                                                    out_x, depth_base);
3508           int channel = 0;
3509 #ifdef USE_NEON
3510           for (; channel <= tranche_depth - 16; channel += 16) {
3511             uint8x16_t a = vld1q_u8(acc + channel);
3512             a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3513             a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3514             vst1q_u8(output_ptr + channel, a);
3515           }
3516           for (; channel <= tranche_depth - 8; channel += 8) {
3517             uint8x8_t a = vld1_u8(acc + channel);
3518             a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3519             a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3520             vst1_u8(output_ptr + channel, a);
3521           }
3522 #endif
3523           for (; channel < tranche_depth; ++channel) {
3524             uint8 a = acc[channel];
3525             a = std::max<uint8>(a, params.quantized_activation_min);
3526             a = std::min<uint8>(a, params.quantized_activation_max);
3527             output_ptr[channel] = static_cast<uint8>(a);
3528           }
3529         }
3530       }
3531     }
3532   }
3533 }
3534 
L2Pool(const PoolParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)3535 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3536                    const float* input_data, const RuntimeShape& output_shape,
3537                    float* output_data) {
3538   ruy::profiler::ScopeLabel label("L2Pool");
3539   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3540   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3541   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3542   const int input_height = input_shape.Dims(1);
3543   const int input_width = input_shape.Dims(2);
3544   const int output_height = output_shape.Dims(1);
3545   const int output_width = output_shape.Dims(2);
3546   const int stride_height = params.stride_height;
3547   const int stride_width = params.stride_width;
3548   // Actually carry out L2 Pool. Code is written in forward mode: we go through
3549   // the input values once, and write to all the pooled regions that it maps to.
3550   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3551   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3552   Eigen::VectorXf in_square(in_mat.rows());
3553   Eigen::VectorXf out_count(out_mat.cols());
3554   out_count.setZero();
3555   // Prefill the output to 0.
3556   out_mat.setZero();
3557   for (int b = 0; b < batches; ++b) {
3558     for (int h = 0; h < input_height; ++h) {
3559       for (int w = 0; w < input_width; ++w) {
3560         // (h_start, h_end) * (w_start, w_end) is the range that the input
3561         // vector projects to.
3562         const int hpad = h + params.padding_values.height;
3563         const int wpad = w + params.padding_values.width;
3564         const int h_start =
3565             (hpad < params.filter_height)
3566                 ? 0
3567                 : (hpad - params.filter_height) / stride_height + 1;
3568         const int h_end = std::min(hpad / stride_height + 1, output_height);
3569         const int w_start =
3570             (wpad < params.filter_width)
3571                 ? 0
3572                 : (wpad - params.filter_width) / stride_width + 1;
3573         const int w_end = std::min(wpad / stride_width + 1, output_width);
3574         // pre-compute square
3575         const int in_offset = w + input_width * (h + input_height * b);
3576         in_square =
3577             in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3578         // compute elementwise sum of squares
3579         for (int ph = h_start; ph < h_end; ++ph) {
3580           for (int pw = w_start; pw < w_end; ++pw) {
3581             const int out_offset = pw + output_width * (ph + output_height * b);
3582             out_mat.col(out_offset) += in_square;
3583             out_count(out_offset)++;
3584           }
3585         }
3586       }
3587     }
3588   }
3589 
3590   out_count = out_count.array().inverse();
3591   out_mat =
3592       (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3593 
3594   const int flat_size = output_shape.FlatSize();
3595   for (int i = 0; i < flat_size; ++i) {
3596     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3597                                                   params.float_activation_min,
3598                                                   params.float_activation_max);
3599   }
3600 }
3601 
LocalResponseNormalization(const tflite::LocalResponseNormalizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)3602 inline void LocalResponseNormalization(
3603     const tflite::LocalResponseNormalizationParams& op_params,
3604     const RuntimeShape& input_shape, const float* input_data,
3605     const RuntimeShape& output_shape, float* output_data) {
3606   ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3607   MatchingFlatSize(input_shape, output_shape);
3608 
3609   const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3610   auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3611 
3612   // Carry out local response normalization, vector by vector.
3613   // Since the data are stored column major, making row-wise operation
3614   // probably not memory efficient anyway, we do an explicit for loop over
3615   // the columns.
3616   const int double_range = op_params.range * 2;
3617   Eigen::VectorXf padded_square(data_in.rows() + double_range);
3618   padded_square.setZero();
3619   const float bias = op_params.bias;
3620   for (int r = 0; r < data_in.cols(); ++r) {
3621     // Do local response normalization for data_in(:, r)
3622     // first, compute the square and store them in buffer for repeated use
3623     padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3624         data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3625     // Then, compute the scale and writes them to data_out
3626     float accumulated_scale = 0;
3627     for (int i = 0; i < double_range; ++i) {
3628       accumulated_scale += padded_square(i);
3629     }
3630     for (int i = 0; i < data_in.rows(); ++i) {
3631       accumulated_scale += padded_square(i + double_range);
3632       data_out(i, r) = bias + accumulated_scale;
3633       accumulated_scale -= padded_square(i);
3634     }
3635   }
3636 
3637   // In a few cases, the pow computation could benefit from speedups.
3638   if (op_params.beta == 1) {
3639     data_out.array() = data_in.array() * data_out.array().inverse();
3640   } else if (op_params.beta == 0.5f) {
3641     data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3642   } else {
3643     data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3644   }
3645 }
3646 
SoftmaxImpl(const SoftmaxParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data,int start_batch,int end_batch)3647 inline void SoftmaxImpl(const SoftmaxParams& params,
3648                         const RuntimeShape& input_shape,
3649                         const float* input_data,
3650                         const RuntimeShape& output_shape, float* output_data,
3651                         int start_batch, int end_batch) {
3652   ruy::profiler::ScopeLabel label("Softmax/Impl");
3653   MatchingFlatSize(input_shape, output_shape);
3654 
3655   const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3656   const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3657                                       logit_size, end_batch - start_batch);
3658   MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3659                            end_batch - start_batch);
3660   // Compute the exponential first, removing the max coefficient for numerical
3661   // stability.
3662   out_mat =
3663       (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3664   // We are separating out the exp function so that exp can be vectorized.
3665   out_mat = out_mat.array().exp();
3666   // Normalize to get the activations.
3667   Eigen::Array<float, 1, Eigen::Dynamic> scale =
3668       out_mat.array().colwise().sum().inverse();
3669   out_mat.array().rowwise() *= scale;
3670 }
3671 
3672 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
SoftmaxWorkerTaskSoftmaxWorkerTask3673   SoftmaxWorkerTask(const SoftmaxParams& params,
3674                     const RuntimeShape& input_shape, const float* input_data,
3675                     const RuntimeShape& output_shape, float* output_data,
3676                     int start_batch, int end_batch)
3677       : params(params),
3678         input_shape(input_shape),
3679         input_data(input_data),
3680         output_shape(output_shape),
3681         output_data(output_data),
3682         start_batch(start_batch),
3683         end_batch(end_batch) {}
RunSoftmaxWorkerTask3684   void Run() override {
3685     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3686                 start_batch, end_batch);
3687   }
3688 
3689  private:
3690   const tflite::SoftmaxParams& params;
3691   const RuntimeShape& input_shape;
3692   const float* input_data;
3693   const RuntimeShape& output_shape;
3694   float* output_data;
3695   int start_batch;
3696   int end_batch;
3697 };
3698 
3699 inline void Softmax(const SoftmaxParams& params,
3700                     const RuntimeShape& input_shape, const float* input_data,
3701                     const RuntimeShape& output_shape, float* output_data,
3702                     CpuBackendContext* cpu_backend_context = nullptr) {
3703   ruy::profiler::ScopeLabel label("Softmax");
3704 
3705   // We picture softmax input as a 2-D matrix while the last dim is the logit
3706   // dim, and the rest dims will be the batch dim for the 2-D matrix.
3707   const int batch_size =
3708       FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3709   constexpr int kMinBatchPerThread = 8;
3710   int thread_count = batch_size / kMinBatchPerThread;
3711   thread_count = thread_count > 0 ? thread_count : 1;
3712   const int capped_thread_count =
3713       cpu_backend_context == nullptr
3714           ? 1
3715           : std::min(thread_count, cpu_backend_context->max_num_threads());
3716   if (capped_thread_count == 1) {
3717     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3718                 batch_size);
3719   } else {
3720     std::vector<SoftmaxWorkerTask> tasks;
3721     // TODO(b/131746020) don't create new heap allocations every time.
3722     // At least we make it a single heap allocation by using reserve().
3723     tasks.reserve(capped_thread_count);
3724     int batch_start = 0;
3725     for (int i = 0; i < capped_thread_count; ++i) {
3726       // Try to distribute the tasks as even as possible.
3727       int batch_end =
3728           batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3729       tasks.emplace_back(params, input_shape, input_data, output_shape,
3730                          output_data, batch_start, batch_end);
3731       batch_start = batch_end;
3732     }
3733     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3734                                     cpu_backend_context);
3735   }
3736 }
3737 
// Maps a rescaled probability (probability already divided by the output
// scale) into the quantized output domain: round to the nearest integer and
// add the output zero point. The caller clamps the result to T's range.
template <typename T>
inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
  const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
  return prob_rnd + zero_point;
}

#if !__aarch64__
// With ARM64, rounding is faster than add + truncation.
// Elsewhere, round by adding 0.5 and truncating. This matches std::round for
// non-negative inputs, which is what softmax produces (a non-negative table
// value times the normalizer, assuming a positive output scale).
// The zero point is added here too, so this specialization gives the same
// result as the generic template used on aarch64.
template <>
inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
                                              int32_t zero_point) {
  return static_cast<int32_t>(prob_rescaled + 0.5f) + zero_point;
}
#endif
3752 
PopulateSoftmaxLookupTable(SoftmaxParams * data,float input_scale,float beta)3753 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3754                                        float beta) {
3755   const float scale = -input_scale * beta;
3756   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3757   for (int32_t val = 0; val <= max_uint8; ++val) {
3758     data->table[max_uint8 - val] = expf(scale * val);
3759   }
3760 }
3761 
3762 template <typename In, typename Out>
Softmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const In * input_data,const RuntimeShape & output_shape,Out * output_data)3763 inline void Softmax(const SoftmaxParams& params,
3764                     const RuntimeShape& input_shape, const In* input_data,
3765                     const RuntimeShape& output_shape, Out* output_data) {
3766   const int trailing_dim = input_shape.DimensionsCount() - 1;
3767   const int excluding_last_dim =
3768       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3769   const int last_dim =
3770       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3771 
3772   const int32_t clamp_max = std::numeric_limits<Out>::max();
3773   const int32_t clamp_min = std::numeric_limits<Out>::min();
3774   for (int i = 0; i < excluding_last_dim; ++i) {
3775     int32_t max_val = std::numeric_limits<In>::min();
3776     // Find max quantized value.
3777     for (int j = 0; j < last_dim; ++j) {
3778       max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3779     }
3780 
3781     float sum_exp = 0.0f;
3782     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3783     const float* table_offset = &params.table[max_uint8 - max_val];
3784     // Calculate normalizer sum(exp(x)).
3785     for (int j = 0; j < last_dim; ++j) {
3786       sum_exp += table_offset[input_data[j]];
3787     }
3788 
3789     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3790     // Normalize and quantize probabilities.
3791     for (int j = 0; j < last_dim; ++j) {
3792       const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3793       const int32_t prob_quantized =
3794           QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3795       output_data[j] = static_cast<Out>(
3796           std::max(std::min(clamp_max, prob_quantized), clamp_min));
3797     }
3798     input_data += last_dim;
3799     output_data += last_dim;
3800   }
3801 }
3802 
// Here's the softmax LUT optimization strategy:
// For softmax, we can apply a mathematically equivalent transformation:
//
// softmax(x) = e^x / sum(e^x, 0...n)  ===> equals to
// softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
//
// For quantization, `x` in our case is (input_q - input_zp) * input_s
// (beta is folded into input_s when the table is built).
// For the uint8 case (int8 can be handled similarly), the range is [0, 255].
//
// So if we let
// CONST = (255 - input_zp) * input_s
// then we will have:
// softmax(x) = e^((input_q - 255) * input_s) --------- (1)
//         /
// sum(e^((input_q - 255) * input_s), 0...n)   -------- (2)
//
// (In the kernel the inputs are additionally shifted so that the row maximum,
// rather than the literal value 255, lands on table index 255.)
//
// The good thing about (1) is that it lies in the range (0, 1], so we can
// approximate it with a uint16 value:
//  (1) = uint16_out * 1 / (2^16 - 1).
//
// So (1) is table(input_q) * 1 / (2^16 - 1), where
// table(v) = round(e^((v - 255) * input_s) * (2^16 - 1)),
// and (2) is essentially:
// sum(table(input_q), 0...n) / (2^16 - 1).
//
// Since (output_q - output_zp) * output_s = softmax(x):
// output_q = table(input_q)
//            /
// (sum(table(input_q), 0...n) * output_s)
//             +
//   output_zp
//
// The uint16 table is stored as two uint8 tables holding the high byte
// (uint8_table1) and the low byte (uint8_table2) of each entry.
// We could further improve performance by using uint8 instead of uint16,
// but then we may lose some accuracy, so that trade-off needs care.
PopulateSoftmaxUInt8LookupTable(SoftmaxParams * data,float input_scale,float beta)3837 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3838                                             float input_scale, float beta) {
3839   const float scale = input_scale * beta;
3840   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3841   const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3842 
3843   for (int32_t val = 0; val <= max_uint8; ++val) {
3844     float input_to_exp = scale * (val - max_uint8);
3845     int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3846     temp = std::min(max_uint16, temp);
3847     uint8_t part1 = temp >> 8;
3848     uint8_t part2 = temp & 0xff;
3849     data->uint8_table1[val] = static_cast<uint8_t>(part1);
3850     data->uint8_table2[val] = static_cast<uint8_t>(part2);
3851   }
3852 }
3853 
FindMaxValue(int size,const uint8_t * input_data,uint8_t offset)3854 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3855   int32_t max_val = std::numeric_limits<uint8_t>::min();
3856   int j = 0;
3857 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3858   uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3859   uint8x16_t offset_dup = vdupq_n_u8(offset);
3860   for (; j <= size - 16; j += 16) {
3861     uint8x16_t input_value = vld1q_u8(input_data + j);
3862     input_value = veorq_u8(input_value, offset_dup);
3863     max_val_dup = vmaxq_u8(input_value, max_val_dup);
3864   }
3865   max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3866 #endif
3867 
3868   for (; j < size; ++j) {
3869     max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3870   }
3871   return max_val;
3872 }
3873 
#ifdef USE_NEON
// Narrows 16 int32 values (four int32x4 vectors) to int8 with signed
// saturation and stores them contiguously at `output`.
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
// The vectors are written in reverse index order: val[3] goes to
// output[0..3], val[2] to output[4..7], val[1] to output[8..11] and
// val[0] to output[12..15].
inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
  // int32 -> int16 (saturating), pairing {val[1], val[0]} and {val[3], val[2]}
  // with the lower-indexed-in-memory half placed first.
  const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
                                          vqmovn_s32(value_to_store.val[0]));
  const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
                                          vqmovn_s32(value_to_store.val[2]));
  // int16 -> int8 (saturating); result_2 supplies the first 8 bytes.
  const int8x16_t result =
      vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
  vst1q_s8(output, result);
}

// Unsigned variant: narrows 16 int32 values to uint8 and stores them at
// `output`, with the same reversed vector order as the int8 overload.
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
// Each lane is reinterpreted as uint32 before the unsigned saturating narrow,
// so a negative lane would saturate to 255 rather than 0. Callers are
// expected to clamp first (SoftmaxInt8LUT clamps to [clamp_min, clamp_max]
// before calling this).
inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
  const uint16x8_t result_1 =
      vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
                   vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
  const uint16x8_t result_2 =
      vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
                   vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
  const uint8x16_t result =
      vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
  vst1q_u8(output, result);
}

#endif
3902 
3903 template <typename In, typename Out>
SoftmaxInt8LUT(const SoftmaxParams & params,const RuntimeShape & input_shape,const In * input_data,const RuntimeShape & output_shape,Out * output_data)3904 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3905                            const RuntimeShape& input_shape,
3906                            const In* input_data,
3907                            const RuntimeShape& output_shape, Out* output_data) {
3908   const int trailing_dim = input_shape.DimensionsCount() - 1;
3909   const int excluding_last_dim =
3910       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3911   const int last_dim =
3912       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3913 
3914   const int32_t clamp_max = std::numeric_limits<Out>::max();
3915   const int32_t clamp_min = std::numeric_limits<Out>::min();
3916 
3917   // Offset is used to interpret the input data "correctly".
3918   // If the input is uint8, the data will be unchanged.
3919   // If the input is int8, since it will be reinterpret as uint8.
3920   // e.g.,
3921   // int8 127 will be applied "offset" to become 255 in uint8.
3922   uint8_t offset = 0;
3923   if (std::is_same<In, int8>::value) {
3924     offset = 0x80;
3925   }
3926 
3927   const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3928 
3929 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3930   // This code uses ARM64-only instructions.
3931   // TODO(b/143709993): Port to ARMv7
3932 
3933   // Load the tables into registers. (4*4 128-bit registers)
3934   uint8x16x4_t table1[4];
3935   table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3936   table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3937   table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3938   table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3939 
3940   uint8x16x4_t table2[4];
3941   table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3942   table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3943   table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3944   table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3945 #endif
3946 
3947   for (int i = 0; i < excluding_last_dim; ++i) {
3948     // Find max quantized value.
3949     int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3950 
3951     int32 sum_exp = 0;
3952     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3953     const uint8_t table_offset = max_uint8 - max_val;
3954 
3955     // Calculate normalizer sum(exp(x)).
3956     int sum_j = 0;
3957 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3958     uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3959     uint8x16_t offset_dup = vdupq_n_u8(offset);
3960     uint32x4_t sum_4 = vdupq_n_u32(0);
3961     const int multiplier_shift = 8;
3962     for (; sum_j <= last_dim - 16; sum_j += 16) {
3963       uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3964       input_value = veorq_u8(input_value, offset_dup);
3965       input_value = vaddq_u8(input_value, table_offset_dup);
3966 
3967       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3968       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3969 
3970       uint16x8_t exp_value1 =
3971           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3972       uint16x8_t exp_value2 =
3973           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3974 
3975       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3976       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3977 
3978       sum_4 = vpadalq_u16(sum_4, exp_value1);
3979       sum_4 = vpadalq_u16(sum_4, exp_value2);
3980     }
3981     int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3982                vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3983     sum_exp += temp;
3984 
3985 #endif
3986     for (; sum_j < last_dim; ++sum_j) {
3987       const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3988 
3989       uint8_t part1 = params.uint8_table1[index];
3990       uint8_t part2 = params.uint8_table2[index];
3991       sum_exp += ((part1 << 8) + part2);
3992     }
3993 
3994     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3995 
3996     int32 multiplier, shift;
3997     QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
3998 
3999     // Normalize and quantize probabilities.
4000     int j = 0;
4001 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
4002     const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
4003     const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
4004     const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
4005 
4006     for (; j <= last_dim - 16; j += 16) {
4007       uint8x16_t input_value = vld1q_u8(input_data_uint + j);
4008       input_value = veorq_u8(input_value, offset_dup);
4009       input_value = vaddq_u8(input_value, table_offset_dup);
4010 
4011       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
4012       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
4013 
4014       uint16x8_t exp_value1 =
4015           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
4016       uint16x8_t exp_value2 =
4017           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
4018 
4019       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
4020       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
4021 
4022       int32x4x4_t output_value;
4023       output_value.val[0] =
4024           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
4025       output_value.val[1] =
4026           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
4027       output_value.val[2] =
4028           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
4029       output_value.val[3] =
4030           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
4031 
4032       int32x4x4_t temp_val =
4033           MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
4034 
4035       temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
4036       temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
4037       temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
4038       temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
4039 
4040       temp_val.val[0] =
4041           vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
4042       temp_val.val[1] =
4043           vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
4044       temp_val.val[2] =
4045           vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
4046       temp_val.val[3] =
4047           vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
4048 
4049       StoreValue(temp_val, output_data + j);
4050     }
4051 #endif
4052     for (; j < last_dim; ++j) {
4053       const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
4054       const uint8_t part1 = params.uint8_table1[index];
4055       const uint8_t part2 = params.uint8_table2[index];
4056       const int32_t exp_value = (part1 << 8) + part2;
4057       const int32_t output_value =
4058           MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
4059 
4060       output_data[j] = static_cast<Out>(std::max(
4061           std::min(clamp_max, output_value + params.zero_point), clamp_min));
4062     }
4063     input_data_uint += last_dim;
4064     output_data += last_dim;
4065   }
4066 }
4067 
LogSoftmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4068 inline void LogSoftmax(const SoftmaxParams& params,
4069                        const RuntimeShape& input_shape, const float* input_data,
4070                        const RuntimeShape& output_shape, float* output_data) {
4071   ruy::profiler::ScopeLabel label("LogSoftmax");
4072   const int trailing_dim = input_shape.DimensionsCount() - 1;
4073   const int outer_size =
4074       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4075   const int depth =
4076       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4077 
4078   for (int i = 0; i < outer_size; ++i) {
4079     VectorMap<const float> block_input(input_data + i * depth, depth, 1);
4080     VectorMap<float> block_output(output_data + i * depth, depth, 1);
4081     // Find max element value which we'll use to ensure numerical stability
4082     // taking advantage of the following equality:
4083     // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
4084     const float max = block_input.maxCoeff();
4085     const float log_sum = std::log((block_input.array() - max).exp().sum());
4086     block_output = block_input.array() - max - log_sum;
4087   }
4088 }
4089 
4090 // Backwards compatibility. Less optimized than below version.
LogSoftmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)4091 inline void LogSoftmax(const SoftmaxParams& params,
4092                        const RuntimeShape& input_shape, const uint8* input_data,
4093                        const RuntimeShape& output_shape, uint8* output_data) {
4094   reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4095                             output_data);
4096 }
4097 
4098 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max)...)
4099 // as done in tf.nn.log_softmax to prevent underflow and overflow.
4100 // This is in contrast to just log(softmax(x))
4101 //
4102 // To handle quantization, first dequantize the inputs (from doing
4103 // e^(input scale * val) where we ignore the zero point since it cancels
4104 // out during subtraction due to the ln) and do a rescale at the end to int8.
4105 //
4106 // Notably this makes use of float and is intended as the optimized
4107 // form for quantized execution on CPU. For a fully integer version,
4108 // see the reference op.
4109 //
4110 // TODO(tflite): notes for optimization:
4111 // 1) See if e^ is also bottleneck in the reference fully-integer
4112 // version and apply lookup there and compare.
4113 template <typename T>
LogSoftmax(const SoftmaxParams & params,float input_scale,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)4114 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
4115                        const RuntimeShape& input_shape, const T* input_data,
4116                        const RuntimeShape& output_shape, T* output_data) {
4117   ruy::profiler::ScopeLabel label("LogSoftmax");
4118   const int trailing_dim = input_shape.DimensionsCount() - 1;
4119   const int excluding_last_dim =
4120       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4121   const int last_dim =
4122       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4123 
4124   const int32_t clamp_max = std::numeric_limits<T>::max();
4125   const int32_t clamp_min = std::numeric_limits<T>::min();
4126 
4127   int32_t zero_point_offset = 0;
4128   if (std::is_same<T, int8_t>::value) {
4129     zero_point_offset = 128;
4130   }
4131   for (int i = 0; i < excluding_last_dim; ++i) {
4132     T max_val = std::numeric_limits<T>::min();
4133     // Find max quantized value.
4134     for (int j = 0; j < last_dim; ++j) {
4135       max_val = std::max(max_val, input_data[j]);
4136     }
4137 
4138     float sum_exp = 0.0f;
4139     const int32_t max_q8 = std::numeric_limits<T>::max();
4140     // Offset into table to compute exp(scale*(x - xmax)) instead of
4141     // exp(scale*(x)) to prevent overflow.
4142     const float* table_offset = &params.table[max_q8 - max_val];
4143     // Calculate sum(exp(scale*(x - x_max))).
4144     for (int j = 0; j < last_dim; ++j) {
4145       sum_exp += table_offset[input_data[j]];
4146     }
4147     const float log_sum_exp = std::log(sum_exp);
4148 
4149     // params.scale is the output scale.
4150     const float scale = input_scale / params.scale;
4151     const float precomputed =
4152         (input_scale * (max_val + zero_point_offset) + log_sum_exp) /
4153         params.scale;
4154     for (int j = 0; j < last_dim; ++j) {
4155       // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
4156       // output_scale.
4157       const float log_prob = scale * input_data[j] - precomputed;
4158 
4159       // TODO(tflite): look into better solution.
4160       // Use std::rint over std::round (which is used in
4161       // FakeQuant) since it's multiple times faster on tested arm32.
4162       const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
4163       output_data[j] = static_cast<T>(
4164           std::max(std::min(clamp_max, prob_quantized), clamp_min));
4165     }
4166     input_data += last_dim;
4167     output_data += last_dim;
4168   }
4169 }
4170 
Logistic(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4171 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4172                      const RuntimeShape& output_shape, float* output_data) {
4173   ruy::profiler::ScopeLabel label("Logistic");
4174   auto input_map = MapAsVector(input_data, input_shape);
4175   auto output_map = MapAsVector(output_data, output_shape);
4176   output_map.array() =
4177       input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4178 }
4179 
4180 // Convenience version that allows, for example, generated-code calls to be
4181 // uniform between data types.
Logistic(const LogisticParams &,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4182 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4183                      const float* input_data, const RuntimeShape& output_shape,
4184                      float* output_data) {
4185   // Drop params: not needed.
4186   Logistic(input_shape, input_data, output_shape, output_data);
4187 }
4188 
Logistic(const LogisticParams & params,const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)4189 inline void Logistic(const LogisticParams& params,
4190                      const RuntimeShape& input_shape, const int16* input_data,
4191                      const RuntimeShape& output_shape, int16* output_data) {
4192   ruy::profiler::ScopeLabel label("Logistic/Int16");
4193   const int flat_size = MatchingFlatSize(input_shape, output_shape);
4194 
4195   for (int i = 0; i < flat_size; i++) {
4196   }
4197 
4198   int c = 0;
4199   const int16* input_data_ptr = input_data;
4200   int16* output_data_ptr = output_data;
4201 #ifdef GEMMLOWP_NEON
4202   {
4203     // F0 uses 0 integer bits, range [-1, 1].
4204     // This is the return type of math functions such as tanh, logistic,
4205     // whose range is in [-1, 1].
4206     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4207     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4208     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4209 
4210     for (; c <= flat_size - 16; c += 16) {
4211       F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4212       F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4213       F0 output0 = gemmlowp::logistic(input0);
4214       F0 output1 = gemmlowp::logistic(input1);
4215       vst1q_s16(output_data_ptr, output0.raw());
4216       vst1q_s16(output_data_ptr + 8, output1.raw());
4217 
4218       input_data_ptr += 16;
4219       output_data_ptr += 16;
4220     }
4221     for (; c <= flat_size - 8; c += 8) {
4222       F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4223       F0 output = gemmlowp::logistic(input);
4224       vst1q_s16(output_data_ptr, output.raw());
4225 
4226       input_data_ptr += 8;
4227       output_data_ptr += 8;
4228     }
4229   }
4230 #endif
4231 #ifdef GEMMLOWP_SSE4
4232   {
4233     // F0 uses 0 integer bits, range [-1, 1].
4234     // This is the return type of math functions such as tanh, logistic,
4235     // whose range is in [-1, 1].
4236     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4237     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4238     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4239 
4240     for (; c <= flat_size - 16; c += 16) {
4241       F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4242           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4243       F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4244           reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4245       F0 output0 = gemmlowp::logistic(input0);
4246       F0 output1 = gemmlowp::logistic(input1);
4247       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4248                        output0.raw().v);
4249       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4250                        output1.raw().v);
4251       input_data_ptr += 16;
4252       output_data_ptr += 16;
4253     }
4254     for (; c <= flat_size - 8; c += 8) {
4255       F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4256           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4257       F0 output = gemmlowp::logistic(input);
4258       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4259                        output.raw().v);
4260       input_data_ptr += 8;
4261       output_data_ptr += 8;
4262     }
4263   }
4264 #endif
4265 
4266   {
4267     // F0 uses 0 integer bits, range [-1, 1].
4268     // This is the return type of math functions such as tanh, logistic,
4269     // whose range is in [-1, 1].
4270     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4271     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4272     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4273 
4274     for (; c < flat_size; ++c) {
4275       F3 input = F3::FromRaw(*input_data_ptr);
4276       F0 output = gemmlowp::logistic(input);
4277       *output_data_ptr = output.raw();
4278 
4279       ++input_data_ptr;
4280       ++output_data_ptr;
4281     }
4282   }
4283 }
4284 
Tanh(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4285 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
4286                  const RuntimeShape& output_shape, float* output_data) {
4287   ruy::profiler::ScopeLabel label("Tanh");
4288   auto input_map = MapAsVector(input_data, input_shape);
4289   auto output_map = MapAsVector(output_data, output_shape);
4290   output_map.array() = input_map.array().tanh();
4291 }
4292 
4293 // Convenience version that allows, for example, generated-code calls to be
4294 // uniform between data types.
Tanh(const TanhParams &,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4295 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
4296                  const float* input_data, const RuntimeShape& output_shape,
4297                  float* output_data) {
4298   // Drop params: not needed.
4299   Tanh(input_shape, input_data, output_shape, output_data);
4300 }
4301 
Tanh(const TanhParams & params,const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)4302 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4303                  const int16* input_data, const RuntimeShape& output_shape,
4304                  int16* output_data) {
4305   ruy::profiler::ScopeLabel label("Tanh/Int16");
4306   const int input_left_shift = params.input_left_shift;
4307   // Support for shifts is limited until we have a parameterized version of
4308   // SaturatingRoundingMultiplyByPOT().
4309   TFLITE_DCHECK_GE(input_left_shift, 0);
4310   TFLITE_DCHECK_LE(input_left_shift, 1);
4311 
4312   const int flat_size = MatchingFlatSize(input_shape, output_shape);
4313 
4314   int c = 0;
4315   const int16* input_data_ptr = input_data;
4316   int16* output_data_ptr = output_data;
4317 #ifdef GEMMLOWP_NEON
4318   {
4319     // F0 uses 0 integer bits, range [-1, 1].
4320     // This is the return type of math functions such as tanh, logistic,
4321     // whose range is in [-1, 1].
4322     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4323     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4324     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4325 
4326     if (input_left_shift == 0) {
4327       for (; c <= flat_size - 16; c += 16) {
4328         F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4329         F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4330         F0 output0 = gemmlowp::tanh(input0);
4331         F0 output1 = gemmlowp::tanh(input1);
4332         vst1q_s16(output_data_ptr, output0.raw());
4333         vst1q_s16(output_data_ptr + 8, output1.raw());
4334 
4335         input_data_ptr += 16;
4336         output_data_ptr += 16;
4337       }
4338       for (; c <= flat_size - 8; c += 8) {
4339         F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4340         F0 output = gemmlowp::tanh(input);
4341         vst1q_s16(output_data_ptr, output.raw());
4342 
4343         input_data_ptr += 8;
4344         output_data_ptr += 8;
4345       }
4346     } else {
4347       for (; c <= flat_size - 16; c += 16) {
4348         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4349             vld1q_s16(input_data_ptr)));
4350         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4351             vld1q_s16(input_data_ptr + 8)));
4352         F0 output0 = gemmlowp::tanh(input0);
4353         F0 output1 = gemmlowp::tanh(input1);
4354         vst1q_s16(output_data_ptr, output0.raw());
4355         vst1q_s16(output_data_ptr + 8, output1.raw());
4356 
4357         input_data_ptr += 16;
4358         output_data_ptr += 16;
4359       }
4360       for (; c <= flat_size - 8; c += 8) {
4361         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4362             vld1q_s16(input_data_ptr)));
4363         F0 output = gemmlowp::tanh(input);
4364         vst1q_s16(output_data_ptr, output.raw());
4365 
4366         input_data_ptr += 8;
4367         output_data_ptr += 8;
4368       }
4369     }
4370   }
4371 #endif
4372 #ifdef GEMMLOWP_SSE4
4373   {
4374     // F0 uses 0 integer bits, range [-1, 1].
4375     // This is the return type of math functions such as tanh, logistic,
4376     // whose range is in [-1, 1].
4377     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4378     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4379     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4380 
4381     if (input_left_shift == 0) {
4382       for (; c <= flat_size - 16; c += 16) {
4383         F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4384             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4385         F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4386             reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4387         F0 output0 = gemmlowp::tanh(input0);
4388         F0 output1 = gemmlowp::tanh(input1);
4389         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4390                          output0.raw().v);
4391         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4392                          output1.raw().v);
4393 
4394         input_data_ptr += 16;
4395         output_data_ptr += 16;
4396       }
4397       for (; c <= flat_size - 8; c += 8) {
4398         F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4399             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4400         F0 output = gemmlowp::tanh(input);
4401         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4402                          output.raw().v);
4403         input_data_ptr += 8;
4404         output_data_ptr += 8;
4405       }
4406     } else {
4407       for (; c <= flat_size - 16; c += 16) {
4408         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4409             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4410                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4411         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4412             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4413                 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4414         F0 output0 = gemmlowp::tanh(input0);
4415         F0 output1 = gemmlowp::tanh(input1);
4416         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4417                          output0.raw().v);
4418         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4419                          output1.raw().v);
4420 
4421         input_data_ptr += 16;
4422         output_data_ptr += 16;
4423       }
4424       for (; c <= flat_size - 8; c += 8) {
4425         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4426             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4427                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4428         F0 output = gemmlowp::tanh(input);
4429         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4430                          output.raw().v);
4431         input_data_ptr += 8;
4432         output_data_ptr += 8;
4433       }
4434     }
4435   }
4436 #endif
4437 
4438   {
4439     // F0 uses 0 integer bits, range [-1, 1].
4440     // This is the return type of math functions such as tanh, logistic,
4441     // whose range is in [-1, 1].
4442     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4443     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4444     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4445 
4446     if (input_left_shift == 0) {
4447       for (; c < flat_size; ++c) {
4448         F3 input = F3::FromRaw(*input_data_ptr);
4449         F0 output = gemmlowp::tanh(input);
4450         *output_data_ptr = output.raw();
4451 
4452         ++input_data_ptr;
4453         ++output_data_ptr;
4454       }
4455     } else {
4456       for (; c < flat_size; ++c) {
4457         F3 input = F3::FromRaw(
4458             gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4459         F0 output = gemmlowp::tanh(input);
4460         *output_data_ptr = output.raw();
4461 
4462         ++input_data_ptr;
4463         ++output_data_ptr;
4464       }
4465     }
4466   }
4467 }
4468 
4469 template <typename SrcT, typename DstT>
Cast(const RuntimeShape & input_shape,const SrcT * input_data,const RuntimeShape & output_shape,DstT * output_data)4470 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4471                  const RuntimeShape& output_shape, DstT* output_data) {
4472   ruy::profiler::ScopeLabel label("Cast");
4473   auto input_map = MapAsVector(input_data, input_shape);
4474   auto output_map = MapAsVector(output_data, output_shape);
4475   output_map.array() = input_map.array().template cast<DstT>();
4476 }
4477 
Floor(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4478 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4479                   const RuntimeShape& output_shape, float* output_data) {
4480   ruy::profiler::ScopeLabel label("Floor");
4481   auto input_map = MapAsVector(input_data, input_shape);
4482   auto output_map = MapAsVector(output_data, output_shape);
4483   output_map.array() = Eigen::floor(input_map.array());
4484 }
4485 
Ceil(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4486 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4487                  const RuntimeShape& output_shape, float* output_data) {
4488   ruy::profiler::ScopeLabel label("Ceil");
4489   auto input_map = MapAsVector(input_data, input_shape);
4490   auto output_map = MapAsVector(output_data, output_shape);
4491   output_map.array() = Eigen::ceil(input_map.array());
4492 }
4493 
4494 #ifdef USE_NEON
ResizeBilinearKernel(const float * input_ptr,int32 depth,float scale,float * output_ptr)4495 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
4496                                  float scale, float* output_ptr) {
4497   int ic = 0;
4498   // Handle 32 input channels at a time.
4499   for (; ic <= depth - 32; ic += 32) {
4500     float32x4x2_t input[4];
4501     for (int i = 0; i < 4; i++) {
4502       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
4503       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
4504     }
4505     float32x4x2_t acc[4];
4506     for (int i = 0; i < 4; i++) {
4507       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
4508       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
4509     }
4510     for (int i = 0; i < 4; i++) {
4511       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
4512       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
4513     }
4514     for (int i = 0; i < 4; i++) {
4515       vst1q_f32(output_ptr, acc[i].val[0]);
4516       vst1q_f32(output_ptr + 4, acc[i].val[1]);
4517       output_ptr += 8;
4518     }
4519     input_ptr += 32;
4520   }
4521   // Handle 16 input channels at a time.
4522   for (; ic <= depth - 16; ic += 16) {
4523     float32x4x2_t input[2];
4524     for (int i = 0; i < 2; i++) {
4525       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
4526       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
4527     }
4528     float32x4x2_t acc[2];
4529     for (int i = 0; i < 2; i++) {
4530       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
4531       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
4532     }
4533     for (int i = 0; i < 2; i++) {
4534       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
4535       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
4536     }
4537     for (int i = 0; i < 2; i++) {
4538       vst1q_f32(output_ptr, acc[i].val[0]);
4539       vst1q_f32(output_ptr + 4, acc[i].val[1]);
4540       output_ptr += 8;
4541     }
4542     input_ptr += 16;
4543   }
4544   // Handle 8 input channels at a time.
4545   for (; ic <= depth - 8; ic += 8) {
4546     float32x4x2_t input;
4547     input.val[0] = vld1q_f32(input_ptr);
4548     input.val[1] = vld1q_f32(input_ptr + 4);
4549 
4550     float32x4x2_t acc;
4551     acc.val[0] = vld1q_f32(output_ptr);
4552     acc.val[1] = vld1q_f32(output_ptr + 4);
4553     acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
4554     acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
4555 
4556     vst1q_f32(output_ptr, acc.val[0]);
4557     vst1q_f32(output_ptr + 4, acc.val[1]);
4558 
4559     input_ptr += 8;
4560     output_ptr += 8;
4561   }
4562   // Handle 4 input channels at a time.
4563   for (; ic <= depth - 4; ic += 4) {
4564     float32x4_t input = vld1q_f32(input_ptr);
4565     float32x4_t acc = vld1q_f32(output_ptr);
4566 
4567     acc = vmlaq_n_f32(acc, input, scale);
4568     vst1q_f32(output_ptr, acc);
4569 
4570     input_ptr += 4;
4571     output_ptr += 4;
4572   }
4573   // Handle 1 input channel at a time.
4574   for (; ic < depth; ic++) {
4575     *output_ptr += *input_ptr * scale;
4576     output_ptr++;
4577     input_ptr++;
4578   }
4579 }
4580 #else
ResizeBilinearKernel(const float * input_ptr,int32 depth,float scale,float * output_ptr)4581 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
4582                                  float scale, float* output_ptr) {
4583   for (int32 i = 0; i < depth; i++) {
4584     *output_ptr += *input_ptr * scale;
4585     output_ptr++;
4586     input_ptr++;
4587   }
4588 }
4589 #endif
4590 
ResizeBilinearKernel2x2(int32 x0,int32 x1,int32 y0,int32 y1,int32 x,int32 y,int32 depth,int32 batch,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4591 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
4592                                     int32 x, int32 y, int32 depth, int32 batch,
4593                                     const RuntimeShape& input_shape,
4594                                     const float* input_data,
4595                                     const RuntimeShape& output_shape,
4596                                     float* output_data) {
4597   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4598   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4599   const int32 input_width = input_shape.Dims(2);
4600   const int32 output_width = output_shape.Dims(2);
4601 
4602   const int32 input_x_offset = (x1 - x0) * depth;
4603   const int32 input_y_offset = (y1 - y0) * depth * input_width;
4604   const int32 output_x_offset = depth;
4605   const int32 output_y_offset = depth * output_width;
4606 
4607 #ifdef USE_NEON
4608   TFLITE_DCHECK(x1 >= x0);
4609   TFLITE_DCHECK(y1 >= y0);
4610 
4611   int ic = 0;
4612   // Handle 8 input channels at a time.
4613   for (; ic <= depth - 8; ic += 8) {
4614     const float* input_ptr = nullptr;
4615 
4616     float32x4x2_t x0y0;
4617     input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
4618     x0y0.val[0] = vld1q_f32(input_ptr);
4619     x0y0.val[1] = vld1q_f32(input_ptr + 4);
4620 
4621     float32x4x2_t x1y0;
4622     input_ptr += input_x_offset;
4623     x1y0.val[0] = vld1q_f32(input_ptr);
4624     x1y0.val[1] = vld1q_f32(input_ptr + 4);
4625 
4626     float32x4x2_t x0y1;
4627     input_ptr += -input_x_offset + input_y_offset;
4628     x0y1.val[0] = vld1q_f32(input_ptr);
4629     x0y1.val[1] = vld1q_f32(input_ptr + 4);
4630 
4631     float32x4x2_t x1y1;
4632     input_ptr += input_x_offset;
4633     x1y1.val[0] = vld1q_f32(input_ptr);
4634     x1y1.val[1] = vld1q_f32(input_ptr + 4);
4635 
4636     // Top left corner.
4637     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
4638     vst1q_f32(output_ptr, x0y0.val[0]);
4639     vst1q_f32(output_ptr + 4, x0y0.val[1]);
4640 
4641     // Top right corner.
4642     output_ptr += output_x_offset;
4643     float32x4x2_t tr;
4644     tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
4645     tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
4646     tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
4647     tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
4648 
4649     vst1q_f32(output_ptr, tr.val[0]);
4650     vst1q_f32(output_ptr + 4, tr.val[1]);
4651 
4652     // Bottom left corner.
4653     output_ptr += -output_x_offset + output_y_offset;
4654     float32x4x2_t bl;
4655     bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
4656     bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
4657     bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
4658     bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
4659     vst1q_f32(output_ptr, bl.val[0]);
4660     vst1q_f32(output_ptr + 4, bl.val[1]);
4661 
4662     // Bottom right corner.
4663     output_ptr += output_x_offset;
4664     float32x4x2_t br;
4665     br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
4666     br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
4667     br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
4668     br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
4669     br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
4670     br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
4671     vst1q_f32(output_ptr, br.val[0]);
4672     vst1q_f32(output_ptr + 4, br.val[1]);
4673   }
4674   // Handle 4 input channels at a time.
4675   for (; ic <= depth - 4; ic += 4) {
4676     const float* input_ptr =
4677         &input_data[Offset(input_shape, batch, y0, x0, ic)];
4678     float32x4_t x0y0 = vld1q_f32(input_ptr);
4679     float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
4680     float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
4681     float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
4682 
4683     // Top left corner.
4684     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
4685     vst1q_f32(output_ptr, x0y0);
4686 
4687     // Top right corner.
4688     output_ptr += output_x_offset;
4689     float32x4_t tr = vaddq_f32(x0y0, x1y0);
4690     tr = vmulq_n_f32(tr, 0.5f);
4691     vst1q_f32(output_ptr, tr);
4692 
4693     // Bottom left corner.
4694     output_ptr += -output_x_offset + output_y_offset;
4695     float32x4_t bl = vaddq_f32(x0y0, x0y1);
4696     bl = vmulq_n_f32(bl, 0.5f);
4697     vst1q_f32(output_ptr, bl);
4698 
4699     // Bottom right corner.
4700     output_ptr += output_x_offset;
4701     float32x4_t br = vaddq_f32(x1y0, x1y1);
4702     br = vmlaq_n_f32(bl, br, 0.5f);
4703     br = vmulq_n_f32(br, 0.5f);
4704     vst1q_f32(output_ptr, br);
4705   }
4706   // Handle one input channel at a time.
4707   for (; ic < depth; ic++) {
4708     const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
4709 
4710     float x0y0 = input_data[input_offset];
4711     float x1y0 = input_data[input_offset + input_x_offset];
4712     float x0y1 = input_data[input_offset + input_y_offset];
4713     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
4714 
4715     // Top left corner.
4716     const int32 output_offset = Offset(output_shape, batch, y, x, ic);
4717     output_data[output_offset] = x0y0;
4718 
4719     // Top right corner.
4720     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
4721 
4722     // Bottom left corner.
4723     float output = (x0y0 + x0y1) / 2;
4724     output_data[output_offset + output_y_offset] = output;
4725 
4726     // Bottom right corner.
4727     output_data[output_offset + output_x_offset + output_y_offset] =
4728         (output + ((x1y0 + x1y1) / 2)) / 2;
4729   }
4730 #else
4731   for (int ch = 0; ch < depth; ch++) {
4732     const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
4733 
4734     float x0y0 = input_data[input_offset];
4735     float x1y0 = input_data[input_offset + input_x_offset];
4736     float x0y1 = input_data[input_offset + input_y_offset];
4737     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
4738 
4739     // Top left corner.
4740     const int32 output_offset = Offset(output_shape, batch, y, x, ch);
4741     output_data[output_offset] = x0y0;
4742 
4743     // Top right corner.
4744     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
4745 
4746     // Bottom left corner.
4747     float output = (x0y0 + x0y1) / 2;
4748     output_data[output_offset + output_y_offset] = output;
4749 
4750     // Bottom right corner.
4751     output_data[output_offset + output_x_offset + output_y_offset] =
4752         (output + ((x1y0 + x1y1) / 2)) / 2;
4753   }
4754 #endif
4755 }
4756 
ResizeBilinear2x2(int32 batches,int32 input_height,int32 input_width,int32 depth,int32 output_height,int32 output_width,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)4757 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
4758                               int32 input_width, int32 depth,
4759                               int32 output_height, int32 output_width,
4760                               const RuntimeShape& input_shape,
4761                               const float* input_data,
4762                               const RuntimeShape& output_shape,
4763                               float* output_data) {
4764   for (int b = 0; b < batches; b++) {
4765     for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
4766       for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
4767         int32 x1 = std::min(x0 + 1, input_width - 1);
4768         int32 y1 = std::min(y0 + 1, input_height - 1);
4769         ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
4770                                 input_data, output_shape, output_data);
4771       }
4772     }
4773   }
4774 }
4775 
ResizeBilinearGeneric(int32 batches,int32 input_height,int32 input_width,int32 depth,int32 output_height,int32 output_width,float height_scale,float width_scale,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data,const bool half_pixel_centers)4776 inline void ResizeBilinearGeneric(
4777     int32 batches, int32 input_height, int32 input_width, int32 depth,
4778     int32 output_height, int32 output_width, float height_scale,
4779     float width_scale, const RuntimeShape& input_shape, const float* input_data,
4780     const RuntimeShape& output_shape, float* output_data,
4781     const bool half_pixel_centers) {
4782   memset(output_data, 0,
4783          batches * output_height * output_width * depth * sizeof(float));
4784 
4785   int32 output_offset = 0;
4786   for (int b = 0; b < batches; ++b) {
4787     for (int y = 0; y < output_height; ++y) {
4788       float input_y;
4789       int32 y0, y1;
4790       reference_ops::ComputeInterpolationValues(
4791           y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
4792           &y1);
4793       for (int x = 0; x < output_width; ++x) {
4794         float input_x;
4795         int32 x0, x1;
4796         reference_ops::ComputeInterpolationValues(
4797             x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
4798             &x1);
4799         float* output_ptr = &output_data[output_offset];
4800 
4801         // Run kernel on the 4 corners of the bilinear resize algorithm.
4802         int32 input_offset = Offset(input_shape, b, y0, x0, 0);
4803         float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
4804         const float* input_ptr = &input_data[input_offset];
4805         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4806 
4807         input_offset = Offset(input_shape, b, y0, x1, 0);
4808         scale = (1 - (input_y - y0)) * (input_x - x0);
4809         input_ptr = &input_data[input_offset];
4810         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4811 
4812         input_offset = Offset(input_shape, b, y1, x0, 0);
4813         scale = (input_y - y0) * (1 - (input_x - x0));
4814         input_ptr = &input_data[input_offset];
4815         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4816 
4817         input_offset = Offset(input_shape, b, y1, x1, 0);
4818         scale = (input_y - y0) * (input_x - x0);
4819         input_ptr = &input_data[input_offset];
4820         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4821 
4822         output_offset += depth;
4823       }
4824     }
4825   }
4826 }
4827 
4828 template <typename T>
ResizeBilinearGenericSmallChannel(int32 batches,int32 input_height,int32 input_width,int32 depth,int32 output_height,int32 output_width,float height_scale,float width_scale,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data,const bool half_pixel_centers)4829 inline void ResizeBilinearGenericSmallChannel(
4830     int32 batches, int32 input_height, int32 input_width, int32 depth,
4831     int32 output_height, int32 output_width, float height_scale,
4832     float width_scale, const RuntimeShape& input_shape, const T* input_data,
4833     const RuntimeShape& output_shape, T* output_data,
4834     const bool half_pixel_centers) {
4835   T* output_ptr = &output_data[0];
4836   for (int b = 0; b < batches; ++b) {
4837     for (int y = 0; y < output_height; ++y) {
4838       float input_y;
4839       int32 y0, y1;
4840       reference_ops::ComputeInterpolationValues(
4841           y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
4842           &y1);
4843       for (int x = 0; x < output_width; ++x) {
4844         float input_x;
4845         int32 x0, x1;
4846         reference_ops::ComputeInterpolationValues(
4847             x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
4848             &x1);
4849 
4850         int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
4851                                  Offset(input_shape, b, y0, x1, 0),
4852                                  Offset(input_shape, b, y1, x0, 0),
4853                                  Offset(input_shape, b, y1, x1, 0)};
4854         float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
4855                           (1 - (input_y - y0)) * (input_x - x0),
4856                           (input_y - y0) * (1 - (input_x - x0)),
4857                           (input_y - y0) * (input_x - x0)};
4858 
4859         for (int d = 0; d < depth; d++) {
4860           const T* input_ptr = &input_data[d];
4861           *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
4862                                          input_ptr[input_offset[1]] * scale[1] +
4863                                          input_ptr[input_offset[2]] * scale[2] +
4864                                          input_ptr[input_offset[3]] * scale[3]);
4865         }
4866       }
4867     }
4868   }
4869 }
4870 
ResizeBilinear(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const float * input_data,const RuntimeShape & output_size_shape,const int32 * output_size_data,const RuntimeShape & unextended_output_shape,float * output_data)4871 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
4872                            const RuntimeShape& unextended_input_shape,
4873                            const float* input_data,
4874                            const RuntimeShape& output_size_shape,
4875                            const int32* output_size_data,
4876                            const RuntimeShape& unextended_output_shape,
4877                            float* output_data) {
4878   ruy::profiler::ScopeLabel label("ResizeBilinear");
4879   // If half_pixel_centers is True, align_corners must be False.
4880   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
4881   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
4882   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
4883   const RuntimeShape input_shape =
4884       RuntimeShape::ExtendedShape(4, unextended_input_shape);
4885   const RuntimeShape output_shape =
4886       RuntimeShape::ExtendedShape(4, unextended_output_shape);
4887 
4888   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
4889   int32 input_height = input_shape.Dims(1);
4890   int32 input_width = input_shape.Dims(2);
4891   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
4892 
4893   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
4894   int32 output_height = output_size_data[0];
4895   int32 output_width = output_size_data[1];
4896 
4897   // Specialize for 2x2 upsample.
4898   if (!op_params.align_corners && !op_params.half_pixel_centers &&
4899       output_height == 2 * input_height && output_width == 2 * input_width) {
4900     ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
4901                       output_width, input_shape, input_data, output_shape,
4902                       output_data);
4903   } else {
4904     float height_scale = static_cast<float>(input_height) / output_height;
4905     float width_scale = static_cast<float>(input_width) / output_width;
4906     if (op_params.align_corners && output_height > 1) {
4907       height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
4908     }
4909     if (op_params.align_corners && output_width > 1) {
4910       width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
4911     }
4912 
4913     ResizeBilinearGeneric(batches, input_height, input_width, depth,
4914                           output_height, output_width, height_scale,
4915                           width_scale, input_shape, input_data, output_shape,
4916                           output_data, op_params.half_pixel_centers);
4917   }
4918 }
4919 
4920 // TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
4921 // or int16 arithmetic.
ResizeBilinear(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const uint8 * input_data,const RuntimeShape & output_size_shape,const int32 * output_size_data,const RuntimeShape & unextended_output_shape,uint8 * output_data)4922 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
4923                            const RuntimeShape& unextended_input_shape,
4924                            const uint8* input_data,
4925                            const RuntimeShape& output_size_shape,
4926                            const int32* output_size_data,
4927                            const RuntimeShape& unextended_output_shape,
4928                            uint8* output_data) {
4929   ruy::profiler::ScopeLabel label("ResizeBilinear");
4930   // If half_pixel_centers is True, align_corners must be False.
4931   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
4932   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
4933   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
4934   const RuntimeShape input_shape =
4935       RuntimeShape::ExtendedShape(4, unextended_input_shape);
4936   const RuntimeShape output_shape =
4937       RuntimeShape::ExtendedShape(4, unextended_output_shape);
4938 
4939   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
4940   int32 input_height = input_shape.Dims(1);
4941   int32 input_width = input_shape.Dims(2);
4942   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
4943 
4944   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
4945   int32 output_height = output_size_data[0];
4946   int32 output_width = output_size_data[1];
4947 
4948   float height_scale =
4949       (op_params.align_corners && output_height > 1)
4950           ? (static_cast<float>(input_height - 1) / (output_height - 1))
4951           : (static_cast<float>(input_height) / output_height);
4952 
4953   float width_scale =
4954       (op_params.align_corners && output_width > 1)
4955           ? (static_cast<float>(input_width - 1) / (output_width - 1))
4956           : (static_cast<float>(input_width) / output_width);
4957 
4958   ResizeBilinearGenericSmallChannel<uint8>(
4959       batches, input_height, input_width, depth, output_height, output_width,
4960       height_scale, width_scale, input_shape, input_data, output_shape,
4961       output_data, op_params.half_pixel_centers);
4962 }
4963 
4964 // Helper methods for BatchToSpaceND.
4965 // `spatial_index_dim` specifies post-crop offset index in this spatial
4966 // dimension, i.e. spatial offset introduced by flattening batch to spatial
4967 // dimension minus the crop size at beginning. `block_shape_dim` is the block
4968 // size in current dimension. `input_dim` and `output_dim` are input and output
4969 // size of BatchToSpaceND operation in current dimension.
4970 // Output start index is inclusive and end index is exclusive.
inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
                          int input_dim, int output_dim, int* start_index,
                          int* end_index) {
  // Input index i maps to output position i * block_shape_dim +
  // spatial_index_dim. Both bounds are computed as
  // (x + block_shape_dim - 1) / block_shape_dim, which is
  // ceil(x / block_shape_dim) for x >= 0 (assumes block_shape_dim > 0).
  const int round_up = block_shape_dim - 1;

  // Smallest i whose output position is >= 0. A negative numerator (when
  // spatial_index_dim > round_up) truncates toward zero to a value <= 0,
  // which the clamp below turns into 0.
  const int first = (round_up - spatial_index_dim) / block_shape_dim;

  // Exclusive end: smallest i whose output position reaches output_dim. A
  // non-positive result means no index qualifies and the range is empty.
  const int past_last =
      (output_dim - spatial_index_dim + round_up) / block_shape_dim;

  *start_index = first > 0 ? first : 0;
  *end_index = past_last < input_dim ? past_last : input_dim;
}
4984 
4985 template <typename T>
BatchToSpaceND(const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const int32 * block_shape_data,const RuntimeShape & unextended_input3_shape,const int32 * crops_data,const RuntimeShape & unextended_output_shape,T * output_data)4986 inline void BatchToSpaceND(
4987     const RuntimeShape& unextended_input1_shape, const T* input1_data,
4988     const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4989     const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4990     const RuntimeShape& unextended_output_shape, T* output_data) {
4991   ruy::profiler::ScopeLabel label("BatchToSpaceND");
4992 
4993   TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4994   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4995   TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4996                    unextended_output_shape.DimensionsCount());
4997 
4998   // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4999   auto extend_shape = [](const RuntimeShape& shape) {
5000     if (shape.DimensionsCount() == 4) {
5001       return shape;
5002     }
5003     RuntimeShape new_shape(4, 1);
5004     new_shape.SetDim(0, shape.Dims(0));
5005     new_shape.SetDim(1, shape.Dims(1));
5006     new_shape.SetDim(3, shape.Dims(2));
5007     return new_shape;
5008   };
5009   const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
5010   const RuntimeShape output_shape = extend_shape(unextended_output_shape);
5011 
5012   const int output_width = output_shape.Dims(2);
5013   const int output_height = output_shape.Dims(1);
5014   const int output_batch_size = output_shape.Dims(0);
5015 
5016   const int depth = input1_shape.Dims(3);
5017   const int input_width = input1_shape.Dims(2);
5018   const int input_height = input1_shape.Dims(1);
5019   const int input_batch_size = input1_shape.Dims(0);
5020 
5021   const int block_shape_height = block_shape_data[0];
5022   const int block_shape_width =
5023       unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
5024   const int crops_top = crops_data[0];
5025   const int crops_left =
5026       unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
5027 
5028   for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
5029     const int out_batch = in_batch % output_batch_size;
5030     const int spatial_offset = in_batch / output_batch_size;
5031 
5032     int in_h_start = 0;
5033     int in_h_end = 0;
5034     // GetIndexRange ensures start and end indices are in [0, output_height).
5035     GetIndexRange(spatial_offset / block_shape_width - crops_top,
5036                   block_shape_height, input_height, output_height, &in_h_start,
5037                   &in_h_end);
5038 
5039     for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
5040       const int out_h = in_h * block_shape_height +
5041                         spatial_offset / block_shape_width - crops_top;
5042       TFLITE_DCHECK_GE(out_h, 0);
5043       TFLITE_DCHECK_LT(out_h, output_height);
5044 
5045       int in_w_start = 0;
5046       int in_w_end = 0;
5047       // GetIndexRange ensures start and end indices are in [0, output_width).
5048       GetIndexRange(spatial_offset % block_shape_width - crops_left,
5049                     block_shape_width, input_width, output_width, &in_w_start,
5050                     &in_w_end);
5051 
5052       for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
5053         const int out_w = in_w * block_shape_width +
5054                           spatial_offset % block_shape_width - crops_left;
5055         TFLITE_DCHECK_GE(out_w, 0);
5056         TFLITE_DCHECK_LT(out_w, output_width);
5057         T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
5058         const T* in =
5059             input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
5060         memcpy(out, in, depth * sizeof(T));
5061       }
5062     }
5063   }
5064 }
5065 
// Fills `num` consecutive objects of type T, starting at `ptr`, with `value`.
//
// memset() replicates a single byte, so it is used whenever every byte of
// `value`'s object representation is identical. Examples are any all-zero
// value, any 1-byte type, or an int32 of -1. Otherwise each element is copied
// individually.
//
// Deciding on the byte pattern rather than on `value == 0` matters for -0.0f:
// it compares equal to zero but is not all-zero bytes, so the old test wrote
// it out as +0.0f. It also avoids passing a floating-point `value` to
// memset's int parameter.
template <typename T>
void TypedMemset(void* ptr, T value, size_t num) {
  unsigned char value_bytes[sizeof(T)];
  memcpy(value_bytes, &value, sizeof(T));

  bool single_byte_pattern = true;
  for (size_t i = 1; i < sizeof(T); ++i) {
    if (value_bytes[i] != value_bytes[0]) {
      single_byte_pattern = false;
      break;
    }
  }

  if (single_byte_pattern) {
    memset(ptr, value_bytes[0], num * sizeof(T));
    return;
  }

  // General case: copy the full object representation into each slot.
  char* pos = static_cast<char*>(ptr);
  for (size_t i = 0; i < num; ++i) {
    memcpy(pos, &value, sizeof(T));
    pos += sizeof(T);
  }
}
5081 
5082 // This makes heavy use of Offset, along with conditional branches. There may be
5083 // opportunities for improvement.
5084 //
5085 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
5086 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
5087 // equivalent to a simple input1_data.  For Pad, it should point to a zero
5088 // value.
5089 //
5090 // Note that two typenames are required, so that T=P=int32 is considered a
5091 // specialization distinct from P=int32.
5092 template <typename T, typename P>
PadImpl(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)5093 inline void PadImpl(const tflite::PadParams& op_params,
5094                     const RuntimeShape& input_shape, const T* input_data,
5095                     const P* pad_value_ptr, const RuntimeShape& output_shape,
5096                     T* output_data) {
5097   ruy::profiler::ScopeLabel label("Pad4DSlowImpl");
5098   const RuntimeShape ext_input_shape =
5099       RuntimeShape::ExtendedShape(4, input_shape);
5100   const RuntimeShape ext_output_shape =
5101       RuntimeShape::ExtendedShape(4, output_shape);
5102   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5103   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5104 
5105   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5106   // to 4 dims (yes, we are "padding the padding").
5107   std::vector<int> left_padding_copy(4, 0);
5108   const int left_padding_extend = 4 - op_params.left_padding_count;
5109   for (int i = 0; i < op_params.left_padding_count; ++i) {
5110     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5111   }
5112   std::vector<int> right_padding_copy(4, 0);
5113   const int right_padding_extend = 4 - op_params.right_padding_count;
5114   for (int i = 0; i < op_params.right_padding_count; ++i) {
5115     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5116   }
5117 
5118   const int output_batch = ext_output_shape.Dims(0);
5119   const int output_height = ext_output_shape.Dims(1);
5120   const int output_width = ext_output_shape.Dims(2);
5121   const int output_depth = ext_output_shape.Dims(3);
5122 
5123   const int left_b_padding = left_padding_copy[0];
5124   const int left_h_padding = left_padding_copy[1];
5125   const int left_w_padding = left_padding_copy[2];
5126   const int left_d_padding = left_padding_copy[3];
5127 
5128   const int right_b_padding = right_padding_copy[0];
5129   const int right_h_padding = right_padding_copy[1];
5130   const int right_w_padding = right_padding_copy[2];
5131   const int right_d_padding = right_padding_copy[3];
5132 
5133   const int input_depth = ext_input_shape.Dims(3);
5134   const T pad_value = *pad_value_ptr;
5135 
5136   if (left_b_padding != 0) {
5137     TypedMemset<T>(
5138         output_data, pad_value,
5139         left_b_padding * output_height * output_width * output_depth);
5140   }
5141   for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
5142        ++out_b) {
5143     if (left_h_padding != 0) {
5144       TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
5145                      pad_value, left_h_padding * output_width * output_depth);
5146     }
5147     for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
5148          ++out_h) {
5149       if (left_w_padding != 0) {
5150         TypedMemset<T>(
5151             output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
5152             pad_value, left_w_padding * output_depth);
5153       }
5154       for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
5155            ++out_w) {
5156         if (left_d_padding != 0) {
5157           TypedMemset<T>(
5158               output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
5159               pad_value, left_d_padding);
5160         }
5161 
5162         T* out = output_data +
5163                  Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
5164         const T* in = input_data +
5165                       Offset(ext_input_shape, out_b - left_b_padding,
5166                              out_h - left_h_padding, out_w - left_w_padding, 0);
5167         memcpy(out, in, input_depth * sizeof(T));
5168 
5169         if (right_d_padding != 0) {
5170           TypedMemset<T>(
5171               output_data + Offset(ext_output_shape, out_b, out_h, out_w,
5172                                    output_depth - right_d_padding),
5173               pad_value, right_d_padding);
5174         }
5175       }
5176       if (right_w_padding != 0) {
5177         TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
5178                                             output_width - right_w_padding, 0),
5179                        pad_value, right_w_padding * output_depth);
5180       }
5181     }
5182     if (right_h_padding != 0) {
5183       TypedMemset<T>(
5184           output_data + Offset(ext_output_shape, out_b,
5185                                output_height - right_h_padding, 0, 0),
5186           pad_value, right_h_padding * output_width * output_depth);
5187     }
5188   }
5189   if (right_b_padding != 0) {
5190     TypedMemset<T>(
5191         output_data +
5192             Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
5193         pad_value,
5194         right_b_padding * output_height * output_width * output_depth);
5195   }
5196 }
5197 
5198 template <typename T, typename P>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)5199 inline void Pad(const tflite::PadParams& op_params,
5200                 const RuntimeShape& input_shape, const T* input_data,
5201                 const P* pad_value_ptr, const RuntimeShape& output_shape,
5202                 T* output_data) {
5203   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5204           output_data);
5205 }
5206 
5207 // The second (pad-value) input can be int32 when, say, the first is uint8.
5208 template <typename T>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const int32 * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)5209 inline void Pad(const tflite::PadParams& op_params,
5210                 const RuntimeShape& input_shape, const T* input_data,
5211                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5212                 T* output_data) {
5213   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
5214   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
5215           output_shape, output_data);
5216 }
5217 
5218 // This version avoids conflicting template matching.
5219 template <>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const int32 * input_data,const int32 * pad_value_ptr,const RuntimeShape & output_shape,int32 * output_data)5220 inline void Pad(const tflite::PadParams& op_params,
5221                 const RuntimeShape& input_shape, const int32* input_data,
5222                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5223                 int32* output_data) {
5224   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5225           output_data);
5226 }
5227 
5228 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
5229 //
5230 // This pad requires that (a) left and right paddings are in the 4D patterns
5231 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
5232 // T is uint8.
5233 //
5234 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
5235 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
5236 // equivalent to a simple input1_data.  For Pad, it should point to a zero
5237 // value.
5238 //
5239 // Note that two typenames are required, so that T=P=int32 is considered a
5240 // specialization distinct from P=int32.
5241 template <typename T, typename P>
PadImageStyleMemset(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)5242 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
5243                                 const RuntimeShape& input_shape,
5244                                 const T* input_data, const P* pad_value_ptr,
5245                                 const RuntimeShape& output_shape,
5246                                 T* output_data) {
5247   ruy::profiler::ScopeLabel label("PadImageStyle");
5248   const RuntimeShape ext_input_shape =
5249       RuntimeShape::ExtendedShape(4, input_shape);
5250   const RuntimeShape ext_output_shape =
5251       RuntimeShape::ExtendedShape(4, output_shape);
5252   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5253   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5254 
5255   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5256   // to 4 dims (yes, we are "padding the padding").
5257   std::vector<int> left_padding_copy(4, 0);
5258   const int left_padding_extend = 4 - op_params.left_padding_count;
5259   for (int i = 0; i < op_params.left_padding_count; ++i) {
5260     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5261   }
5262   std::vector<int> right_padding_copy(4, 0);
5263   const int right_padding_extend = 4 - op_params.right_padding_count;
5264   for (int i = 0; i < op_params.right_padding_count; ++i) {
5265     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5266   }
5267   // The following padding restrictions are contractual requirements, and
5268   // embody what it means for a padding op to be "image-style".
5269   TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
5270   TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
5271   TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
5272   TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
5273 
5274   const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
5275   const int output_height = ext_output_shape.Dims(1);
5276   const int output_width = ext_output_shape.Dims(2);
5277   const int input_height = ext_input_shape.Dims(1);
5278   const int input_width = ext_input_shape.Dims(2);
5279   const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
5280 
5281   const int left_h_padding = left_padding_copy[1];
5282   const int left_w_padding = left_padding_copy[2];
5283   const int right_h_padding = right_padding_copy[1];
5284   const int right_w_padding = right_padding_copy[2];
5285 
5286   TFLITE_DCHECK_EQ(output_height,
5287                    input_height + left_h_padding + right_h_padding);
5288   TFLITE_DCHECK_EQ(output_width,
5289                    input_width + left_w_padding + right_w_padding);
5290 
5291   const T pad_value = *pad_value_ptr;
5292   const int top_block_size = left_h_padding * output_width * depth;
5293   const size_t num_top_block_bytes = top_block_size * sizeof(T);
5294   const int bottom_block_size = right_h_padding * output_width * depth;
5295   const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
5296   const int left_blocks_size = left_w_padding * depth;
5297   const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
5298   const int right_blocks_size = right_w_padding * depth;
5299   const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
5300   const int inner_line_size = input_width * depth;
5301   const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
5302 
5303   if (input_height == 0) {
5304     memset(output_data, pad_value,
5305            num_top_block_bytes + num_bottom_block_bytes);
5306   } else {
5307     for (int i = 0; i < batch; ++i) {
5308       // For each image in the batch, apply the top padding, then iterate
5309       // through rows, then apply the bottom padding.
5310       //
5311       // By unwinding one iteration, we can combine the first left-margin
5312       // padding with the top padding, and the last right-margin padding with
5313       // the bottom padding.
5314       memset(output_data, pad_value,
5315              num_top_block_bytes + num_left_block_bytes);
5316       output_data += top_block_size + left_blocks_size;
5317       memcpy(output_data, input_data, num_inner_line_bytes);
5318       input_data += inner_line_size;
5319       output_data += inner_line_size;
5320       // One iteration unwound.
5321       // Unwinding this loop affords the opportunity to reorder the loop work
5322       // and hence combine memset() calls.
5323       //
5324       // Before unwinding:
5325       // for (int j = 0; j < input_height; ++j) {
5326       //   // Pad on left, copy central data, pad on right.
5327       //   memset(output_data, pad_value, num_left_block_bytes);
5328       //   output_data += left_blocks_size;
5329       //   memcpy(output_data, input_data, num_inner_line_bytes);
5330       //   input_data += inner_line_size;
5331       //   output_data += inner_line_size;
5332       //   memset(output_data, pad_value, num_right_block_bytes);
5333       //   output_data += right_blocks_size;
5334       // }
5335       for (int j = 1; j < input_height; ++j) {
5336         memset(output_data, pad_value,
5337                num_right_block_bytes + num_left_block_bytes);
5338         output_data += right_blocks_size + left_blocks_size;
5339         memcpy(output_data, input_data, num_inner_line_bytes);
5340         input_data += inner_line_size;
5341         output_data += inner_line_size;
5342       }
5343       memset(output_data, pad_value,
5344              num_right_block_bytes + num_bottom_block_bytes);
5345       output_data += right_blocks_size + bottom_block_size;
5346     }
5347   }
5348 }
5349 
5350 template <typename T, typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)5351 inline void PadImageStyle(const tflite::PadParams& op_params,
5352                           const RuntimeShape& input_shape, const T* input_data,
5353                           const P* pad_value_ptr,
5354                           const RuntimeShape& output_shape, T* output_data) {
5355   reference_ops::PadImageStyle(op_params, input_shape, input_data,
5356                                pad_value_ptr, output_shape, output_data);
5357 }
5358 
5359 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const uint8 * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,uint8 * output_data)5360 inline void PadImageStyle(const tflite::PadParams& op_params,
5361                           const RuntimeShape& input_shape,
5362                           const uint8* input_data, const P* pad_value_ptr,
5363                           const RuntimeShape& output_shape,
5364                           uint8* output_data) {
5365   PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
5366                       output_shape, output_data);
5367 }
5368 
5369 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const float * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,float * output_data)5370 inline void PadImageStyle(const tflite::PadParams& op_params,
5371                           const RuntimeShape& input_shape,
5372                           const float* input_data, const P* pad_value_ptr,
5373                           const RuntimeShape& output_shape,
5374                           float* output_data) {
5375   const float converted_pad_value = static_cast<float>(*pad_value_ptr);
5376   if (converted_pad_value == 0.0f) {
5377     PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
5378                         output_shape, output_data);
5379   } else {
5380     PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5381             output_data);
5382   }
5383 }
5384 
5385 template <typename T>
Slice(const tflite::SliceParams & op_params,const RuntimeShape & input_shape,const RuntimeShape & output_shape,SequentialTensorWriter<T> * writer)5386 inline void Slice(const tflite::SliceParams& op_params,
5387                   const RuntimeShape& input_shape,
5388                   const RuntimeShape& output_shape,
5389                   SequentialTensorWriter<T>* writer) {
5390   ruy::profiler::ScopeLabel label("Slice");
5391   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
5392   TFLITE_DCHECK_LE(op_params.begin_count, 5);
5393   TFLITE_DCHECK_LE(op_params.size_count, 5);
5394   const int begin_count = op_params.begin_count;
5395   const int size_count = op_params.size_count;
5396   // We front-pad the begin and size vectors.
5397   std::array<int, 5> start;
5398   std::array<int, 5> stop;
5399   for (int i = 0; i < 5; ++i) {
5400     int padded_i = 5 - i;
5401     start[i] =
5402         begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
5403     stop[i] =
5404         (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
5405             ? ext_shape.Dims(i)
5406             : start[i] + op_params.size[size_count - padded_i];
5407   }
5408 
5409   for (int i0 = start[0]; i0 < stop[0]; ++i0) {
5410     for (int i1 = start[1]; i1 < stop[1]; ++i1) {
5411       for (int i2 = start[2]; i2 < stop[2]; ++i2) {
5412         for (int i3 = start[3]; i3 < stop[3]; ++i3) {
5413           const int len = stop[4] - start[4];
5414           if (len > 0)
5415             writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
5416         }
5417       }
5418     }
5419   }
5420 }
5421 
5422 template <typename T>
Slice(const tflite::SliceParams & op_params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)5423 inline void Slice(const tflite::SliceParams& op_params,
5424                   const RuntimeShape& input_shape, const T* input_data,
5425                   const RuntimeShape& output_shape, T* output_data) {
5426   SequentialTensorWriter<T> writer(input_data, output_data);
5427   return Slice(op_params, input_shape, output_shape, &writer);
5428 }
5429 
5430 template <typename T>
Slice(const tflite::SliceParams & op_params,const RuntimeShape & input_shape,const TfLiteTensor * input,const RuntimeShape & output_shape,TfLiteTensor * output)5431 inline void Slice(const tflite::SliceParams& op_params,
5432                   const RuntimeShape& input_shape, const TfLiteTensor* input,
5433                   const RuntimeShape& output_shape, TfLiteTensor* output) {
5434   SequentialTensorWriter<T> writer(input, output);
5435   return Slice(op_params, input_shape, output_shape, &writer);
5436 }
5437 
// Note: This implementation is only optimized for the case where the inner
// stride == 1.
//
// Copies the strided slice of the input described by `op_params` into
// `writer`. Shapes of up to 5 dimensions are accepted; the input shape is
// front-padded to 5D. When the innermost stride is 1, each innermost run is
// emitted with a single WriteN call; otherwise elements are written one at a
// time with Write.
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                         const RuntimeShape& unextended_input_shape,
                         const RuntimeShape& unextended_output_shape,
                         SequentialTensorWriter<T>* writer) {
  using strided_slice::LoopCondition;
  using strided_slice::StartForAxis;
  using strided_slice::StopForAxis;

  ruy::profiler::ScopeLabel label("StridedSlice");

  // Note that the output_shape is not used herein.
  // Work on a copy: StridedSlicePadIndices below rewrites the params in place.
  tflite::StridedSliceParams params_copy = op_params;

  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(5, unextended_input_shape);
  // NOTE(review): output_shape is never read below; presumably kept for any
  // checks ExtendedShape itself performs -- confirm before removing.
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(5, unextended_output_shape);

  // Reverse and pad to 5 dimensions because that is what the runtime code
  // requires (ie. all shapes must be 5D and are given backwards).
  strided_slice::StridedSlicePadIndices(&params_copy, 5);

  // Per-axis [start, stop) bounds in element indices of the 5D input.
  const int start_0 = StartForAxis(params_copy, input_shape, 0);
  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
  const int start_1 = StartForAxis(params_copy, input_shape, 1);
  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
  const int start_2 = StartForAxis(params_copy, input_shape, 2);
  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
  const int start_3 = StartForAxis(params_copy, input_shape, 3);
  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
  const int start_4 = StartForAxis(params_copy, input_shape, 4);
  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
  const bool inner_stride_is_1 = params_copy.strides[4] == 1;

  // The loops below iterate directly over flattened offsets instead of
  // per-axis indices. At depth k, offset_k = (offset_{k-1} + index_k) *
  // Dims(k + 1), so after the fourth loop offset_3 is the flat (row-major)
  // offset of element (i0, i1, i2, i3, 0). Start, stop and step of each axis
  // are scaled by the same factor. LoopCondition is used as a termination
  // test: each loop runs while it returns false.
  for (int offset_0 = start_0 * input_shape.Dims(1),
           end_0 = stop_0 * input_shape.Dims(1),
           step_0 = params_copy.strides[0] * input_shape.Dims(1);
       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
       offset_0 += step_0) {
    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
             step_1 = params_copy.strides[1] * input_shape.Dims(2);
         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
         offset_1 += step_1) {
      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
               step_2 = params_copy.strides[2] * input_shape.Dims(3);
           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
           offset_2 += step_2) {
        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
             offset_3 += step_3) {
          // When the stride is 1, the inner loop is equivalent to the
          // optimized slice inner loop. Otherwise, it is identical to the
          // strided_slice reference implementation inner loop.
          if (inner_stride_is_1) {
            // Contiguous innermost run: copy it in a single call.
            const int len = stop_4 - start_4;
            if (len > 0) {
              writer->WriteN(offset_3 + start_4, len);
            }
          } else {
            for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
                 !LoopCondition(offset_4, end_4, params_copy.strides[4]);
                 offset_4 += params_copy.strides[4]) {
              writer->Write(offset_4);
            }
          }
        }
      }
    }
  }
}
5517 
5518 template <typename T>
StridedSlice(const tflite::StridedSliceParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_shape,T * output_data)5519 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5520                          const RuntimeShape& unextended_input_shape,
5521                          const T* input_data,
5522                          const RuntimeShape& unextended_output_shape,
5523                          T* output_data) {
5524   SequentialTensorWriter<T> writer(input_data, output_data);
5525   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5526                   &writer);
5527 }
5528 
5529 template <typename T>
StridedSlice(const tflite::StridedSliceParams & op_params,const RuntimeShape & unextended_input_shape,const TfLiteTensor * input,const RuntimeShape & unextended_output_shape,TfLiteTensor * output)5530 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5531                          const RuntimeShape& unextended_input_shape,
5532                          const TfLiteTensor* input,
5533                          const RuntimeShape& unextended_output_shape,
5534                          TfLiteTensor* output) {
5535   SequentialTensorWriter<T> writer(input, output);
5536   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5537                   &writer);
5538 }
5539 
5540 template <typename T>
Minimum(const RuntimeShape & input1_shape,const T * input1_data,const T * input2_data,const RuntimeShape & output_shape,T * output_data)5541 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5542              const T* input2_data, const RuntimeShape& output_shape,
5543              T* output_data) {
5544   ruy::profiler::ScopeLabel label("TensorFlowMinimum");
5545   auto input1_map = MapAsVector(input1_data, input1_shape);
5546   auto output_map = MapAsVector(output_data, output_shape);
5547   auto min_value = input2_data[0];
5548   output_map.array() = input1_map.array().min(min_value);
5549 }
5550 
5551 // Convenience version that allows, for example, generated-code calls to be
5552 // the same as other binary ops.
5553 template <typename T>
Minimum(const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape &,const T * input2_data,const RuntimeShape & output_shape,T * output_data)5554 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5555                     const RuntimeShape&, const T* input2_data,
5556                     const RuntimeShape& output_shape, T* output_data) {
5557   // Drop shape of second input: not needed.
5558   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
5559 }
5560 
5561 template <typename T>
Maximum(const RuntimeShape & input1_shape,const T * input1_data,const T * input2_data,const RuntimeShape & output_shape,T * output_data)5562 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5563              const T* input2_data, const RuntimeShape& output_shape,
5564              T* output_data) {
5565   ruy::profiler::ScopeLabel label("TensorFlowMaximum");
5566   auto input1_map = MapAsVector(input1_data, input1_shape);
5567   auto output_map = MapAsVector(output_data, output_shape);
5568   auto max_value = input2_data[0];
5569   output_map.array() = input1_map.array().max(max_value);
5570 }
5571 
5572 // Convenience version that allows, for example, generated-code calls to be
5573 // the same as other binary ops.
5574 template <typename T>
Maximum(const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape &,const T * input2_data,const RuntimeShape & output_shape,T * output_data)5575 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5576                     const RuntimeShape&, const T* input2_data,
5577                     const RuntimeShape& output_shape, T* output_data) {
5578   // Drop shape of second input: not needed.
5579   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
5580 }
5581 
5582 template <typename T>
TransposeIm2col(const ConvParams & params,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data)5583 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
5584                      const RuntimeShape& input_shape, const T* input_data,
5585                      const RuntimeShape& filter_shape,
5586                      const RuntimeShape& output_shape, T* im2col_data) {
5587   ruy::profiler::ScopeLabel label("TransposeIm2col");
5588   const int stride_width = params.stride_width;
5589   const int stride_height = params.stride_height;
5590   const int pad_width = params.padding_values.width;
5591   const int pad_height = params.padding_values.height;
5592   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5593   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
5594   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5595   TFLITE_DCHECK(im2col_data);
5596 
5597   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
5598   const int input_height = input_shape.Dims(1);
5599   const int input_width = input_shape.Dims(2);
5600   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
5601   const int filter_height = filter_shape.Dims(1);
5602   const int filter_width = filter_shape.Dims(2);
5603   const int output_height = output_shape.Dims(1);
5604   const int output_width = output_shape.Dims(2);
5605   MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth
5606 
5607   // Construct the MxN sized im2col matrix.
5608   // The rows M, are sub-ordered B x H x W
5609   const RuntimeShape row_shape({1, batches, output_height, output_width});
5610   // The columns, N, are sub-ordered Kh x Kw x Din
5611   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
5612   // Use dimensions M and N to construct dims for indexing directly into im2col
5613   const RuntimeShape im2col_shape(
5614       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
5615 
5616   // Build the im2col matrix by looping through all the input pixels,
5617   // computing their influence on the output, rather than looping through all
5618   // the output pixels. We therefore must initialize the im2col array to zero.
5619   // This is potentially inefficient because we subsequently overwrite bytes
5620   // set here. However, in practice memset is very fast and costs negligible.
5621   memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
5622 
5623   // Loop through the output batches
5624   for (int batch = 0; batch < batches; ++batch) {
5625     // Loop through input pixels one at a time.
5626     for (int in_y = 0; in_y < input_height; ++in_y) {
5627       for (int in_x = 0; in_x < input_width; ++in_x) {
5628         // Loop through the output pixels it will influence
5629         const int out_x_origin = (in_x * stride_width) - pad_width;
5630         const int out_y_origin = (in_y * stride_height) - pad_height;
5631         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
5632           const int out_y = out_y_origin + filter_y;
5633           // Is output pixel within height bounds?
5634           if ((out_y >= 0) && (out_y < output_height)) {
5635             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
5636               const int out_x = out_x_origin + filter_x;
5637               // Is output pixel within width bounds?
5638               if ((out_x >= 0) && (out_x < output_width)) {
5639                 // Copy the input elements of this pixel
5640                 T const* src =
5641                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
5642                 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
5643                 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
5644                 T* dst = im2col_data +
5645                          Offset(im2col_shape, 0, 0, row_offset, col_offset);
5646                 memcpy(dst, src, input_depth * sizeof(T));
5647               }
5648             }
5649           }
5650         }
5651       }
5652     }
5653   }
5654 }
5655 
5656 // Returns in 'im_data' (assumes to be zero-initialized) image patch in storage
5657 // order (height, width, depth), constructed from patches in 'col_data', which
5658 // is required to be in storage order (out_height * out_width, filter_height,
5659 // filter_width, in_depth).  Implementation by Yangqing Jia (jiayq).
5660 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
5661 template <typename T>
Col2im(const T * col_data,const int depth,const int height,const int width,const int filter_h,const int filter_w,const int pad_t,const int pad_l,const int pad_b,const int pad_r,const int stride_h,const int stride_w,T * im_data)5662 void Col2im(const T* col_data, const int depth, const int height,
5663             const int width, const int filter_h, const int filter_w,
5664             const int pad_t, const int pad_l, const int pad_b, const int pad_r,
5665             const int stride_h, const int stride_w, T* im_data) {
5666   ruy::profiler::ScopeLabel label("Col2im");
5667   int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
5668   int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
5669   int h_pad = -pad_t;
5670   for (int h = 0; h < height_col; ++h) {
5671     int w_pad = -pad_l;
5672     for (int w = 0; w < width_col; ++w) {
5673       T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
5674       for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
5675         for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
5676           if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
5677             // TODO(andydavis) Vectorize this loop (if compiler does not).
5678             for (int i = 0; i < depth; ++i) {
5679               im_patch_data[i] += col_data[i];
5680             }
5681           }
5682           im_patch_data += depth;
5683           col_data += depth;
5684         }
5685         // Jump over remaining number of depth.
5686         im_patch_data += depth * (width - filter_w);
5687       }
5688       w_pad += stride_w;
5689     }
5690     h_pad += stride_h;
5691   }
5692 }
5693 
// Adds `bias_data[d]` to channel d of every pixel of a
// (batch_size, height, width, depth) tensor stored contiguously in `im_data`.
// Does nothing when `bias_data` is null.
template <typename T>
void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
             const int height, const int width, const int depth) {
  if (!bias_data) {
    return;
  }
  T* pixel = im_data;
  for (int b = 0; b < batch_size; ++b) {
    for (int y = 0; y < height; ++y) {
      for (int x = 0; x < width; ++x, pixel += depth) {
        for (int c = 0; c < depth; ++c) {
          pixel[c] += bias_data[c];
        }
      }
    }
  }
}
5710 
5711 // TransposeConvV2 expect the weights in HWOI order.
TransposeConvV2(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & hwoi_ordered_filter_shape,const float * hwoi_ordered_filter_data,const RuntimeShape & bias_shape,const float * bias_data,const RuntimeShape & output_shape,float * const output_data,const RuntimeShape & col2im_shape,float * col2im_data,CpuBackendContext * cpu_backend_context)5712 inline void TransposeConvV2(
5713     const ConvParams& params, const RuntimeShape& input_shape,
5714     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5715     const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5716     const float* bias_data, const RuntimeShape& output_shape,
5717     float* const output_data, const RuntimeShape& col2im_shape,
5718     float* col2im_data, CpuBackendContext* cpu_backend_context) {
5719   ruy::profiler::ScopeLabel label("TransposeConvV2/float");
5720   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5721   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5722   TFLITE_DCHECK(col2im_data);
5723   TFLITE_DCHECK(hwoi_ordered_filter_data);
5724 
5725   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5726   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5727   const int output_height = output_shape.Dims(1);
5728   const int output_width = output_shape.Dims(2);
5729   const int output_image_size = output_height * output_width;
5730   const int input_depth =
5731       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5732   const int output_depth =
5733       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5734   const int input_offset = input_image_size * input_depth;
5735   const int output_offset = output_image_size * output_depth;
5736 
5737   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5738   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5739   const int padding_top = params.padding_values.height;
5740   const int padding_bottom =
5741       params.padding_values.height + params.padding_values.height_offset;
5742   const int padding_left = params.padding_values.width;
5743   const int padding_right =
5744       params.padding_values.width + params.padding_values.width_offset;
5745   const int stride_height = params.stride_height;
5746   const int stride_width = params.stride_width;
5747 
5748   const int hwoi_ordered_filter_total_size =
5749       filter_height * filter_width * output_depth;
5750 
5751   cpu_backend_gemm::MatrixParams<float> lhs_params;
5752   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5753   lhs_params.rows = hwoi_ordered_filter_total_size;
5754   lhs_params.cols = input_depth;
5755   float* output_data_p = output_data;
5756   std::fill_n(output_data, output_offset * batch_size, 0.0f);
5757   for (int i = 0; i < batch_size; ++i) {
5758     cpu_backend_gemm::MatrixParams<float> rhs_params;
5759     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5760     rhs_params.rows = input_depth;
5761     rhs_params.cols = input_image_size;
5762     cpu_backend_gemm::MatrixParams<float> dst_params;
5763     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5764     dst_params.rows = hwoi_ordered_filter_total_size;
5765     dst_params.cols = input_image_size;
5766     cpu_backend_gemm::GemmParams<float, float> gemm_params;
5767     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5768                            input_data + input_offset * i, dst_params,
5769                            col2im_data, gemm_params, cpu_backend_context);
5770 
5771     Col2im(col2im_data, output_depth, output_height, output_width,
5772            filter_height, filter_width, padding_top, padding_left,
5773            padding_bottom, padding_right, stride_height, stride_width,
5774            output_data_p);
5775     output_data_p += output_offset;
5776   }
5777   output_data_p = output_data;
5778   BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5779           output_depth);
5780 }
5781 
Quantize(int32_t multiplier,int32_t shift,int32_t total_size,int32_t output_zp,int32_t * scratch,uint8_t * output)5782 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5783                      int32_t output_zp, int32_t* scratch, uint8_t* output) {
5784   ruy::profiler::ScopeLabel label("Quantize/uint8");
5785   int i = 0;
5786   const int32_t output_min = std::numeric_limits<uint8_t>::min();
5787   const int32_t output_max = std::numeric_limits<uint8_t>::max();
5788 
5789 #ifdef USE_NEON
5790   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5791   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5792   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5793 
5794   using gemmlowp::RoundingDivideByPOT;
5795   using gemmlowp::SaturatingRoundingDoublingHighMul;
5796 
5797   for (; i <= total_size - 16; i += 16) {
5798     int32x4x4_t scratch_val;
5799     scratch_val.val[0] = vld1q_s32(scratch + i);
5800     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5801     scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5802     scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5803 
5804     int32x4x4_t temp_val =
5805         MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5806 
5807     temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5808     temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5809     temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5810     temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5811 
5812     temp_val.val[0] =
5813         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5814     temp_val.val[1] =
5815         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5816     temp_val.val[2] =
5817         vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5818     temp_val.val[3] =
5819         vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5820 
5821     const uint16x8_t result_1 =
5822         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5823                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5824     const uint16x8_t result_2 =
5825         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5826                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5827     const uint8x16_t result =
5828         vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5829     vst1q_u8(output + i, result);
5830   }
5831 #endif
5832   for (; i < total_size; ++i) {
5833     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5834     temp += output_zp;
5835     if (temp > output_max) {
5836       temp = output_max;
5837     }
5838     if (temp < output_min) {
5839       temp = output_min;
5840     }
5841     output[i] = static_cast<uint8_t>(temp);
5842   }
5843 }
5844 
// Per-channel requantization of int32 accumulators to int8.
//
// `scratch` holds `total_size` accumulators laid out as rows x channel_size:
// element (n, c) is at index n * channel_size + c. Each accumulator in channel
// c is handled as follows:
//   - scaled by the fixed-point pair (multiplier[c], shift[c]);
//   - offset by `output_zp`;
//   - clamped to [output_min, output_max];
//   - stored as int8 at the same index in `output`.
// `multiplier` and `shift` must each hold at least `channel_size` entries.
// `total_size` is expected to be a non-zero multiple of `channel_size`; only a
// TFLITE_DCHECK guards the multiple requirement, and channel_size == 0 would
// divide by zero.
inline void Quantize(const int32_t* multiplier, const int32_t* shift,
                     int32_t channel_size, int32_t total_size,
                     int32_t output_zp, int32_t output_min, int32_t output_max,
                     int32_t* scratch, int8_t* output) {
  ruy::profiler::ScopeLabel label("Quantize/int8");

  // Here we're trying to quantize the raw accumulators:
  //        output_channels
  //       data data data data data
  // rows  data data data data data
  //       data data data data data
  //          ....
  //
  // In order to minimize the reload of the multipliers & shifts, once we load
  // the multipliers & shifts, we load & quantize the raw accumulators for every
  // row.
#ifdef USE_NEON
  const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
  const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
  const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
  const int32x4_t zeros = vdupq_n_s32(0);
#endif

  TFLITE_DCHECK_EQ(total_size % channel_size, 0);
  const int32_t rows = total_size / channel_size;

  // Next channel to process; shared by the NEON block and the scalar tail.
  int c = 0;

#ifdef USE_NEON
  using gemmlowp::RoundingDivideByPOT;
  // Vectorized path: 8 channels (two int32x4 vectors) per outer iteration.
  for (; c <= channel_size - 8; c += 8) {
    int32x4_t out_shift_1 = vld1q_s32(shift + c);
    int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
    // Positive shifts are applied as a left shift before the multiply.
    int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
    int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);

    // Right shift will be performed as left shift with negative values.
    int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
    int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);

    int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
    int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
    for (int n = 0; n < rows; ++n) {
      int loc = n * channel_size + c;
      int32x4_t acc_1 = vld1q_s32(scratch + loc);
      int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);

      // Saturating Rounding Doubling High Mul.
      acc_1 = vshlq_s32(acc_1, left_shift_1);
      acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
      acc_2 = vshlq_s32(acc_2, left_shift_2);
      acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);

      // Rounding Dividing By POT.
      // NOTE(review): this rounding right shift (vrshlq with a negative
      // count) may round ties differently from the scalar
      // MultiplyByQuantizedMultiplier path below; confirm if bit-exactness
      // between the two paths matters.
      acc_1 = vrshlq_s32(acc_1, right_shift_1);
      acc_2 = vrshlq_s32(acc_2, right_shift_2);

      // Add the output offset.
      acc_1 = vaddq_s32(acc_1, output_offset_vec);
      acc_2 = vaddq_s32(acc_2, output_offset_vec);

      // Apply the activation function.
      acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
      acc_1 = vminq_s32(acc_1, output_activation_max_vec);
      acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
      acc_2 = vminq_s32(acc_2, output_activation_max_vec);

      // Saturating cast to int8 and store to destination.
      const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
      const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
      const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
      const int8x8_t res_s8 = vqmovn_s16(res_s16);
      vst1_s8(output + loc, res_s8);
    }
  }

#endif  // USE_NEON
  // Handle leftover values, one by one. This is very slow.
  // Without USE_NEON, this loop handles every channel.
  for (; c < channel_size; c++) {
    for (int n = 0; n < rows; ++n) {
      int loc = n * channel_size + c;
      int32 acc = scratch[loc];
      acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
      acc += output_zp;
      acc = std::max(acc, output_min);
      acc = std::min(acc, output_max);
      output[loc] = static_cast<int8>(acc);
    }
  }
}
5935 
5936 // TransposeConvV2 expect the weights in HWOI order.
TransposeConvV2(const ConvParams & params,const RuntimeShape & input_shape,const uint8_t * input_data,const RuntimeShape & hwoi_ordered_filter_shape,const uint8_t * hwoi_ordered_filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8_t * output_data,const RuntimeShape & col2im_shape,int32_t * col2im_data,int32_t * scratch_data,CpuBackendContext * cpu_backend_context)5937 inline void TransposeConvV2(
5938     const ConvParams& params, const RuntimeShape& input_shape,
5939     const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5940     const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5941     const int32* bias_data, const RuntimeShape& output_shape,
5942     uint8_t* output_data, const RuntimeShape& col2im_shape,
5943     int32_t* col2im_data, int32_t* scratch_data,
5944     CpuBackendContext* cpu_backend_context) {
5945   ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5946   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5947   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5948   TFLITE_DCHECK(col2im_data);
5949   TFLITE_DCHECK(hwoi_ordered_filter_data);
5950 
5951   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5952   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5953   const int output_height = output_shape.Dims(1);
5954   const int output_width = output_shape.Dims(2);
5955   const int output_image_size = output_height * output_width;
5956   const int input_depth =
5957       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5958   const int output_depth =
5959       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5960   const int input_offset = input_image_size * input_depth;
5961   const int output_offset = output_image_size * output_depth;
5962 
5963   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5964   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5965   const int padding_top = params.padding_values.height;
5966   const int padding_bottom =
5967       params.padding_values.height + params.padding_values.height_offset;
5968   const int padding_left = params.padding_values.width;
5969   const int padding_right =
5970       params.padding_values.width + params.padding_values.width_offset;
5971   const int stride_height = params.stride_height;
5972   const int stride_width = params.stride_width;
5973 
5974   const int hwoi_ordered_filter_total_size =
5975       filter_height * filter_width * output_depth;
5976 
5977   cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5978   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5979   lhs_params.rows = hwoi_ordered_filter_total_size;
5980   lhs_params.cols = input_depth;
5981   lhs_params.zero_point = -params.weights_offset;
5982 
5983   int32_t* scratch_data_p = scratch_data;
5984   std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5985   for (int i = 0; i < batch_size; ++i) {
5986     cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5987     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5988     rhs_params.rows = input_depth;
5989     rhs_params.cols = input_image_size;
5990     rhs_params.zero_point = -params.input_offset;
5991 
5992     cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5993     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5994     dst_params.rows = hwoi_ordered_filter_total_size;
5995     dst_params.cols = input_image_size;
5996 
5997     cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5998     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5999                            input_data + input_offset * i, dst_params,
6000                            col2im_data, gemm_params, cpu_backend_context);
6001 
6002     Col2im(col2im_data, output_depth, output_height, output_width,
6003            filter_height, filter_width, padding_top, padding_left,
6004            padding_bottom, padding_right, stride_height, stride_width,
6005            scratch_data_p);
6006 
6007     scratch_data_p += output_offset;
6008   }
6009   scratch_data_p = scratch_data;
6010   BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
6011           output_depth);
6012 
6013   Quantize(params.output_multiplier, params.output_shift,
6014            output_shape.FlatSize(), params.output_offset, scratch_data,
6015            output_data);
6016 }
6017 
6018 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
6019 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
6020 // reference version. Debug checks are in place to test if this occurs.
6021 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
6022 // version.
ResizeNearestNeighbor(const tflite::ResizeNearestNeighborParams & op_params,const RuntimeShape & unextended_input_shape,const uint8 * input_data,const RuntimeShape & output_size_shape,const int32 * output_size_data,const RuntimeShape & unextended_output_shape,uint8 * output_data)6023 inline void ResizeNearestNeighbor(
6024     const tflite::ResizeNearestNeighborParams& op_params,
6025     const RuntimeShape& unextended_input_shape, const uint8* input_data,
6026     const RuntimeShape& output_size_shape, const int32* output_size_data,
6027     const RuntimeShape& unextended_output_shape, uint8* output_data) {
6028   if (op_params.align_corners || op_params.half_pixel_centers) {
6029     // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
6030     // this kernel.
6031     reference_ops::ResizeNearestNeighbor(
6032         op_params, unextended_input_shape, input_data, output_size_shape,
6033         output_size_data, unextended_output_shape, output_data);
6034     return;
6035   }
6036   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
6037   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
6038 
6039   const RuntimeShape input_shape =
6040       RuntimeShape::ExtendedShape(4, unextended_input_shape);
6041   const RuntimeShape output_shape =
6042       RuntimeShape::ExtendedShape(4, unextended_output_shape);
6043 
6044   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
6045   int32 input_height = input_shape.Dims(1);
6046   int32 input_width = input_shape.Dims(2);
6047   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
6048 
6049   // The Tensorflow version of this op allows resize on the width and height
6050   // axis only.
6051   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
6052   int32 output_height = output_size_data[0];
6053   int32 output_width = output_size_data[1];
6054 
6055   // Convert scales to fixed-point with 16 fractional bits. We add 1 as an
6056   // error factor and to avoid zero scales. For example, with input_height = 1,
6057   // output_height = 3, the float scaling factor would be non-zero at 1/3.
6058   // With fixed-point, this is zero.
6059   int32 height_scale = (input_height << 16) / output_height + 1;
6060   int32 width_scale = (input_width << 16) / output_width + 1;
6061 
6062   const int col_offset = input_shape.Dims(3);
6063   const int row_offset = input_shape.Dims(2) * col_offset;
6064   const int batch_offset = input_shape.Dims(1) * row_offset;
6065 
6066   const uint8* input_ptr = input_data;
6067   uint8* output_ptr = output_data;
6068   for (int b = 0; b < batches; ++b) {
6069     for (int y = 0; y < output_height; ++y) {
6070       int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
6071       // Check offset calculation is the same as the reference version. See
6072       // function comment for details. We check using a non-float version of:
6073       // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
6074       //                                            / output_height)));
6075       TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
6076       TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
6077       const uint8* y_input_ptr = input_ptr + in_y * row_offset;
6078       for (int x = 0; x < output_width; ++x) {
6079         int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
6080         // Check offset calculation is the same as the reference version. See
6081         // function comment for details. We check using a non-float version of:
6082         // TFLITE_DCHECK_EQ(in_y,
6083         //                  std::floor(y * (static_cast<float>(input_width)
6084         //                                      / output_width)));
6085         TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
6086         TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
6087         const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
6088         memcpy(output_ptr, x_input_ptr, depth);
6089         output_ptr += depth;
6090       }
6091     }
6092     input_ptr += batch_offset;
6093   }
6094 }
6095 
6096 template <typename input_type, typename output_type>
Requantize(const input_type * input_data,int32_t size,int32_t effective_scale_multiplier,int32_t effective_scale_shift,int32_t input_zeropoint,int32_t output_zeropoint,output_type * output_data)6097 inline void Requantize(const input_type* input_data, int32_t size,
6098                        int32_t effective_scale_multiplier,
6099                        int32_t effective_scale_shift, int32_t input_zeropoint,
6100                        int32_t output_zeropoint, output_type* output_data) {
6101   reference_ops::Requantize(input_data, size, effective_scale_multiplier,
6102                             effective_scale_shift, input_zeropoint,
6103                             output_zeropoint, output_data);
6104 }
6105 
6106 template <>
6107 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
6108                                         int32_t effective_scale_multiplier,
6109                                         int32_t effective_scale_shift,
6110                                         int32_t input_zeropoint,
6111                                         int32_t output_zeropoint,
6112                                         uint8_t* output_data) {
6113   ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
6114 
6115   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
6116   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
6117 
6118   int i = 0;
6119 #ifdef USE_NEON
6120   // Constants.
6121   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6122   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6123   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6124   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6125 
6126   for (; i <= size - 16; i += 16) {
6127     const int8x16_t input_vec = vld1q_s8(input_data + i);
6128     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
6129     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
6130     int32x4x4_t input;
6131     input.val[0] = vmovl_s16(vget_low_s16(first_half));
6132     input.val[1] = vmovl_s16(vget_high_s16(first_half));
6133     input.val[2] = vmovl_s16(vget_low_s16(second_half));
6134     input.val[3] = vmovl_s16(vget_high_s16(second_half));
6135     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6136     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6137     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6138     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6139 
6140     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6141         input, effective_scale_multiplier, effective_scale_shift);
6142 
6143     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6144     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6145     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6146     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6147     result.val[0] =
6148         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6149     result.val[1] =
6150         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6151     result.val[2] =
6152         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6153     result.val[3] =
6154         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6155 
6156     const uint32x4_t result_val_1_unsigned =
6157         vreinterpretq_u32_s32(result.val[0]);
6158     const uint32x4_t result_val_2_unsigned =
6159         vreinterpretq_u32_s32(result.val[1]);
6160     const uint32x4_t result_val_3_unsigned =
6161         vreinterpretq_u32_s32(result.val[2]);
6162     const uint32x4_t result_val_4_unsigned =
6163         vreinterpretq_u32_s32(result.val[3]);
6164 
6165     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
6166     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
6167     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
6168     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
6169     const uint16x8_t output_first_half =
6170         vcombine_u16(narrowed_val_1, narrowed_val_2);
6171     const uint16x8_t output_second_half =
6172         vcombine_u16(narrowed_val_3, narrowed_val_4);
6173     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
6174     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
6175     const uint8x16_t narrowed_result =
6176         vcombine_u8(narrowed_first_half, narrowed_second_half);
6177     vst1q_u8(output_data + i, narrowed_result);
6178   }
6179 
6180 #endif
6181   for (; i < size; ++i) {
6182     const int32_t input = input_data[i] - input_zeropoint;
6183     const int32_t output =
6184         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6185                                       effective_scale_shift) +
6186         output_zeropoint;
6187     const int32_t clamped_output =
6188         std::max(std::min(output, kMaxOutput), kMinOutput);
6189     output_data[i] = static_cast<uint8_t>(clamped_output);
6190   }
6191 }
6192 
6193 template <>
6194 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
6195                                         int32_t effective_scale_multiplier,
6196                                         int32_t effective_scale_shift,
6197                                         int32_t input_zeropoint,
6198                                         int32_t output_zeropoint,
6199                                         int8_t* output_data) {
6200   ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
6201 
6202   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
6203   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
6204 
6205   int i = 0;
6206 #ifdef USE_NEON
6207   // Constants.
6208   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6209   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6210   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6211   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6212 
6213   for (; i <= size - 16; i += 16) {
6214     const uint8x16_t input_vec = vld1q_u8(input_data + i);
6215     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
6216     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
6217     int32x4x4_t input;
6218     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
6219     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
6220     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
6221     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
6222     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6223     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6224     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6225     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6226 
6227     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6228         input, effective_scale_multiplier, effective_scale_shift);
6229 
6230     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6231     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6232     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6233     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6234     result.val[0] =
6235         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6236     result.val[1] =
6237         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6238     result.val[2] =
6239         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6240     result.val[3] =
6241         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6242 
6243     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
6244     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
6245     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
6246     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
6247     const int16x8_t output_first_half =
6248         vcombine_s16(narrowed_val_1, narrowed_val_2);
6249     const int16x8_t output_second_half =
6250         vcombine_s16(narrowed_val_3, narrowed_val_4);
6251     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
6252     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
6253     const int8x16_t narrowed_result =
6254         vcombine_s8(narrowed_first_half, narrowed_second_half);
6255     vst1q_s8(output_data + i, narrowed_result);
6256   }
6257 
6258 #endif
6259   for (; i < size; ++i) {
6260     const int32_t input = input_data[i] - input_zeropoint;
6261     const int32_t output =
6262         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6263                                       effective_scale_shift) +
6264         output_zeropoint;
6265     const int32_t clamped_output =
6266         std::max(std::min(output, kMaxOutput), kMinOutput);
6267     output_data[i] = static_cast<int8_t>(clamped_output);
6268   }
6269 }
6270 
6271 template <>
6272 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
6273                                        int32_t effective_scale_multiplier,
6274                                        int32_t effective_scale_shift,
6275                                        int32_t input_zeropoint,
6276                                        int32_t output_zeropoint,
6277                                        int8_t* output_data) {
6278   ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
6279 
6280   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
6281   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
6282 
6283   int i = 0;
6284 #ifdef USE_NEON
6285   // Constants.
6286   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6287   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6288   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6289   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6290 
6291   for (; i <= size - 16; i += 16) {
6292     const int8x16_t input_vec = vld1q_s8(input_data + i);
6293     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
6294     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
6295     int32x4x4_t input;
6296     input.val[0] = vmovl_s16(vget_low_s16(first_half));
6297     input.val[1] = vmovl_s16(vget_high_s16(first_half));
6298     input.val[2] = vmovl_s16(vget_low_s16(second_half));
6299     input.val[3] = vmovl_s16(vget_high_s16(second_half));
6300 
6301     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6302     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6303     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6304     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6305 
6306     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6307         input, effective_scale_multiplier, effective_scale_shift);
6308 
6309     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6310     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6311     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6312     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6313     result.val[0] =
6314         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6315     result.val[1] =
6316         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6317     result.val[2] =
6318         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6319     result.val[3] =
6320         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6321 
6322     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
6323     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
6324     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
6325     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
6326     const int16x8_t output_first_half =
6327         vcombine_s16(narrowed_val_1, narrowed_val_2);
6328     const int16x8_t output_second_half =
6329         vcombine_s16(narrowed_val_3, narrowed_val_4);
6330     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
6331     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
6332     const int8x16_t narrowed_result =
6333         vcombine_s8(narrowed_first_half, narrowed_second_half);
6334     vst1q_s8(output_data + i, narrowed_result);
6335   }
6336 
6337 #endif
6338   for (; i < size; ++i) {
6339     const int32_t input = input_data[i] - input_zeropoint;
6340     const int32_t output =
6341         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6342                                       effective_scale_shift) +
6343         output_zeropoint;
6344     const int32_t clamped_output =
6345         std::max(std::min(output, kMaxOutput), kMinOutput);
6346     output_data[i] = static_cast<int8_t>(clamped_output);
6347   }
6348 }
6349 
6350 template <>
6351 inline void Requantize<uint8_t, uint8_t>(
6352     const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
6353     int32_t effective_scale_shift, int32_t input_zeropoint,
6354     int32_t output_zeropoint, uint8_t* output_data) {
6355   ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
6356 
6357   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
6358   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
6359 
6360   int i = 0;
6361 #ifdef USE_NEON
6362   // Constants.
6363   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6364   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6365   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6366   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6367 
6368   for (; i <= size - 16; i += 16) {
6369     const uint8x16_t input_vec = vld1q_u8(input_data + i);
6370     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
6371     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
6372     int32x4x4_t input;
6373     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
6374     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
6375     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
6376     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
6377     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6378     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6379     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6380     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6381 
6382     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6383         input, effective_scale_multiplier, effective_scale_shift);
6384 
6385     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6386     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6387     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6388     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6389     result.val[0] =
6390         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6391     result.val[1] =
6392         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6393     result.val[2] =
6394         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6395     result.val[3] =
6396         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6397 
6398     const uint32x4_t result_val_1_unsigned =
6399         vreinterpretq_u32_s32(result.val[0]);
6400     const uint32x4_t result_val_2_unsigned =
6401         vreinterpretq_u32_s32(result.val[1]);
6402     const uint32x4_t result_val_3_unsigned =
6403         vreinterpretq_u32_s32(result.val[2]);
6404     const uint32x4_t result_val_4_unsigned =
6405         vreinterpretq_u32_s32(result.val[3]);
6406 
6407     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
6408     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
6409     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
6410     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
6411     const uint16x8_t output_first_half =
6412         vcombine_u16(narrowed_val_1, narrowed_val_2);
6413     const uint16x8_t output_second_half =
6414         vcombine_u16(narrowed_val_3, narrowed_val_4);
6415     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
6416     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
6417     const uint8x16_t narrowed_result =
6418         vcombine_u8(narrowed_first_half, narrowed_second_half);
6419     vst1q_u8(output_data + i, narrowed_result);
6420   }
6421 
6422 #endif
6423   for (; i < size; ++i) {
6424     const int32_t input = input_data[i] - input_zeropoint;
6425     const int32_t output =
6426         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6427                                       effective_scale_shift) +
6428         output_zeropoint;
6429     const int32_t clamped_output =
6430         std::max(std::min(output, kMaxOutput), kMinOutput);
6431     output_data[i] = static_cast<uint8_t>(clamped_output);
6432   }
6433 }
6434 
HardSwish(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)6435 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
6436                       const RuntimeShape& output_shape, float* output_data) {
6437   ruy::profiler::ScopeLabel label("HardSwish/Float");
6438   auto size = MatchingFlatSize(input_shape, output_shape);
6439   int i = 0;
6440 #ifdef USE_NEON
6441   const float32x4_t zero = vdupq_n_f32(0.0f);
6442   const float32x4_t three = vdupq_n_f32(3.0f);
6443   const float32x4_t six = vdupq_n_f32(6.0f);
6444   const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
6445 
6446   for (; i <= size - 16; i += 16) {
6447     // 4x partially unrolled version of the loop below. Refer to its comments.
6448     const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
6449     const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
6450     const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
6451     const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
6452     const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
6453     const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
6454     const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
6455     const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
6456     const float32x4_t in_reluish_0 =
6457         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
6458     const float32x4_t in_reluish_1 =
6459         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
6460     const float32x4_t in_reluish_2 =
6461         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
6462     const float32x4_t in_reluish_3 =
6463         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
6464     const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
6465     const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
6466     const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
6467     const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
6468     vst1q_f32(output_data + i + 0, product_0);
6469     vst1q_f32(output_data + i + 4, product_1);
6470     vst1q_f32(output_data + i + 8, product_2);
6471     vst1q_f32(output_data + i + 12, product_3);
6472   }
6473   for (; i <= size - 4; i += 4) {
6474     // The expression to be computed is:
6475     //   out = one_sixth * in * min(six, max(zero, (in + three)))
6476     // We structure the AST to have two roughly balanced, independent branches:
6477     //  - Multiplication: in_scaled = one_sixth * in.
6478     //  - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
6479     // Then the remaining multiplication at the root of the tree.
6480     const float32x4_t in = vld1q_f32(input_data + i);
6481     const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
6482     const float32x4_t in_reluish =
6483         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
6484     const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
6485     vst1q_f32(output_data + i, product);
6486   }
6487 #endif
6488   for (; i < size; i++) {
6489     const float in = input_data[i];
6490     output_data[i] =
6491         in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6492   }
6493 }
6494 
#ifdef USE_NEON
// Narrows the eight int16 lanes of `src` to uint8 with unsigned saturation
// (negative lanes become 0, lanes above 255 become 255) and stores the eight
// resulting bytes to dst[0..7].
inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
  const uint8x8_t narrowed = vqmovun_s16(src);
  vst1_u8(dst, narrowed);
}

// Narrows the eight int16 lanes of `src` to int8 with *signed* saturation
// (lanes are clamped to [-128, 127]) and stores the eight resulting bytes to
// dst[0..7].
inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
  const int8x8_t narrowed = vqmovn_s16(src);
  vst1_s8(dst, narrowed);
}
#endif
6510 
6511 template <typename T>
HardSwish(const HardSwishParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)6512 inline void HardSwish(const HardSwishParams& params,
6513                       const RuntimeShape& input_shape, const T* input_data,
6514                       const RuntimeShape& output_shape, T* output_data) {
6515   ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6516 
6517   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6518 
6519   int i = 0;
6520   // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6521   // amounts that can be zero, in which case we rely on the correct behavior
6522   // of a left shift by zero returning just its first operand unmodified.
6523   // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6524   // buggy in the case of zero shift amounts, see b/137199585. That is why
6525   // this NEON code path is restricted to true ARM NEON, excluding
6526   // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6527   // left shifts is slow scalar code, so there may not be much benefit in
6528   // running that over just plain reference code.
6529   //
6530   // TODO(b/137199585): revisit when this is fixed.
6531 #ifdef __ARM_NEON
6532   const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6533       vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6534   const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6535       vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6536   const int16x8_t negative_reluish_multiplier_exponent =
6537       vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6538   const int16x8_t constant_32767 = vdupq_n_s16(32767);
6539   const int16x8_t output_multiplier_exponent =
6540       vdupq_n_s16(params.output_multiplier_exponent);
6541   const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6542   // 4x unrolled version of the below NEON loop. Read that first.
6543   for (; i <= flat_size - 32; i += 32) {
6544     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6545     const int16x8x2_t input_value_0_1 =
6546         Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6547     const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6548         input_data + i + 16, params.input_zero_point);
6549     const int16x8_t input_value_on_hires_input_scale_0 =
6550         vshlq_n_s16(input_value_0_1.val[0], 7);
6551     const int16x8_t input_value_on_hires_input_scale_1 =
6552         vshlq_n_s16(input_value_0_1.val[1], 7);
6553     const int16x8_t input_value_on_hires_input_scale_2 =
6554         vshlq_n_s16(input_value_2_3.val[0], 7);
6555     const int16x8_t input_value_on_hires_input_scale_3 =
6556         vshlq_n_s16(input_value_2_3.val[1], 7);
6557     const int16x8_t input_value_on_preshift_output_scale_0 =
6558         vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6559                         params.output_multiplier_fixedpoint_int16);
6560     const int16x8_t input_value_on_preshift_output_scale_1 =
6561         vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6562                         params.output_multiplier_fixedpoint_int16);
6563     const int16x8_t input_value_on_preshift_output_scale_2 =
6564         vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6565                         params.output_multiplier_fixedpoint_int16);
6566     const int16x8_t input_value_on_preshift_output_scale_3 =
6567         vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6568                         params.output_multiplier_fixedpoint_int16);
6569     int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6570     int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6571     int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6572     int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6573     reluish_value_0 = vqshlq_s16(
6574         reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6575     reluish_value_1 = vqshlq_s16(
6576         reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6577     reluish_value_2 = vqshlq_s16(
6578         reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6579     reluish_value_3 = vqshlq_s16(
6580         reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6581     reluish_value_0 = vqrdmulhq_n_s16(
6582         reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6583     reluish_value_1 = vqrdmulhq_n_s16(
6584         reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6585     reluish_value_2 = vqrdmulhq_n_s16(
6586         reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6587     reluish_value_3 = vqrdmulhq_n_s16(
6588         reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6589     reluish_value_0 = vqshlq_s16(reluish_value_0,
6590                                  positive_reluish_multiplier_exponent_last_bit);
6591     reluish_value_1 = vqshlq_s16(reluish_value_1,
6592                                  positive_reluish_multiplier_exponent_last_bit);
6593     reluish_value_2 = vqshlq_s16(reluish_value_2,
6594                                  positive_reluish_multiplier_exponent_last_bit);
6595     reluish_value_3 = vqshlq_s16(reluish_value_3,
6596                                  positive_reluish_multiplier_exponent_last_bit);
6597     reluish_value_0 =
6598         vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6599     reluish_value_1 =
6600         vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6601     reluish_value_2 =
6602         vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6603     reluish_value_3 =
6604         vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6605     reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6606     reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6607     reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6608     reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6609     const int16x8_t preshift_output_value_0 =
6610         vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6611     const int16x8_t preshift_output_value_1 =
6612         vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6613     const int16x8_t preshift_output_value_2 =
6614         vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6615     const int16x8_t preshift_output_value_3 =
6616         vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6617     int16x8_t output_value_0 =
6618         vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6619     int16x8_t output_value_1 =
6620         vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6621     int16x8_t output_value_2 =
6622         vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6623     int16x8_t output_value_3 =
6624         vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6625     output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6626     output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6627     output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6628     output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6629     SaturateAndStore(output_value_0, output_data + i);
6630     SaturateAndStore(output_value_1, output_data + i + 8);
6631     SaturateAndStore(output_value_2, output_data + i + 16);
6632     SaturateAndStore(output_value_3, output_data + i + 24);
6633   }
6634   // NEON version of reference_ops::HardSwish. Read that first.
6635   for (; i <= flat_size - 8; i += 8) {
6636     using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6637     const int16x8_t input_value =
6638         Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6639     const int16x8_t input_value_on_hires_input_scale =
6640         vshlq_n_s16(input_value, 7);
6641     const int16x8_t input_value_on_preshift_output_scale =
6642         vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6643                         params.output_multiplier_fixedpoint_int16);
6644     int16x8_t reluish_value = input_value_on_hires_input_scale;
6645     reluish_value = vqshlq_s16(reluish_value,
6646                                positive_reluish_multiplier_exponent_minus_one);
6647     reluish_value = vqrdmulhq_n_s16(reluish_value,
6648                                     params.reluish_multiplier_fixedpoint_int16);
6649     reluish_value = vqshlq_s16(reluish_value,
6650                                positive_reluish_multiplier_exponent_last_bit);
6651     reluish_value =
6652         vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6653     reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6654     const int16x8_t preshift_output_value =
6655         vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6656     int16x8_t output_value =
6657         vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6658     output_value = vaddq_s16(output_value, output_zero_point);
6659     SaturateAndStore(output_value, output_data + i);
6660   }
6661 #endif
6662   // TODO(b/137208495): revisit when unit tests cover reference code.
6663   // Fall back to reference_ops::HardSwish. In general we have preferred
6664   // to duplicate such scalar code rather than call reference code to handle
6665   // leftovers, thinking that code duplication was not a big concern.
6666   // However, most of our unit tests happen to test only optimized code,
6667   // and the quantized HardSwish implementation is nontrivial enough that
6668   // I really want test coverage for the reference code.
6669   if (i < flat_size) {
6670     const RuntimeShape leftover_shape{flat_size - i};
6671     reference_ops::HardSwish(params, leftover_shape, input_data + i,
6672                              leftover_shape, output_data + i);
6673   }
6674 }
6675 
6676 template <typename T>
IntegerExponentPow(const ArithmeticParams & params,const RuntimeShape & unextended_base_shape,const T * base_data,const int exponent,const RuntimeShape & unextended_output_shape,T * output_data)6677 inline void IntegerExponentPow(const ArithmeticParams& params,
6678                                const RuntimeShape& unextended_base_shape,
6679                                const T* base_data, const int exponent,
6680                                const RuntimeShape& unextended_output_shape,
6681                                T* output_data) {
6682   TFLITE_DCHECK_GE(exponent, 1);
6683   if (exponent == 1) {
6684     // copy data over.
6685     std::memcpy(output_data, base_data,
6686                 unextended_base_shape.FlatSize() * sizeof(T));
6687   } else {
6688     IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6689                        unextended_output_shape, output_data);
6690     Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6691         output_data, unextended_output_shape, output_data);
6692     if (exponent % 2 == 1) {
6693       Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6694           output_data, unextended_output_shape, output_data);
6695     }
6696   }
6697 }
6698 
6699 template <typename T>
BroadcastPow4D(const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,T * output_data)6700 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6701                            const T* input1_data,
6702                            const RuntimeShape& unextended_input2_shape,
6703                            const T* input2_data,
6704                            const RuntimeShape& unextended_output_shape,
6705                            T* output_data) {
6706   ruy::profiler::ScopeLabel label("PowBroadcast");
6707 
6708   if (unextended_input2_shape.FlatSize() == 1) {
6709     static const float epsilon = 1e-5;
6710     const T exponent = input2_data[0];
6711     const int int_exponent = static_cast<int>(std::round(exponent));
6712     if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6713         (int_exponent >= 1)) {
6714       ArithmeticParams params;
6715       if (std::is_same<T, float>::value) {
6716         params.float_activation_max = std::numeric_limits<float>::max();
6717         params.float_activation_min = std::numeric_limits<float>::lowest();
6718       } else if (std::is_same<T, int>::value) {
6719         params.quantized_activation_max = std::numeric_limits<int>::max();
6720         params.quantized_activation_min = std::numeric_limits<int>::lowest();
6721       }
6722       IntegerExponentPow(params, unextended_input1_shape, input1_data,
6723                          int_exponent, unextended_output_shape, output_data);
6724       return;
6725     }
6726   }
6727   reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6728                                     unextended_input2_shape, input2_data,
6729                                     unextended_output_shape, output_data);
6730 }
6731 
6732 #ifdef USE_NEON
6733 
ScaleWithNewZeroPoint(const int32x4_t input,const float32x4_t scale_dup,const float32x4_t zero_times_scale_dup,float32x4_t * output)6734 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6735                                   const float32x4_t scale_dup,
6736                                   const float32x4_t zero_times_scale_dup,
6737                                   float32x4_t* output) {
6738 #ifdef __ARM_FEATURE_FMA
6739   *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6740 #else
6741   *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6742                       zero_times_scale_dup);
6743 #endif
6744 }
6745 
6746 #endif  // USE_NEON
6747 
Dequantize(const tflite::DequantizationParams & op_params,const RuntimeShape & input_shape,const uint8_t * input_data,const RuntimeShape & output_shape,float * output_data)6748 inline void Dequantize(const tflite::DequantizationParams& op_params,
6749                        const RuntimeShape& input_shape,
6750                        const uint8_t* input_data,
6751                        const RuntimeShape& output_shape, float* output_data) {
6752   ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6753   const int32 zero_point = op_params.zero_point;
6754   const double scale = op_params.scale;
6755   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6756 
6757   int i = 0;
6758 #ifdef USE_NEON
6759   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6760   const float32x4_t zero_times_scale_dup =
6761       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6762   for (; i <= flat_size - 8; i += 8) {
6763     const uint8x8_t input_u8 = vld1_u8(input_data + i);
6764     const uint16x8_t input_u16 = vmovl_u8(input_u8);
6765     const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6766     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6767     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6768     const int32x4_t val_low = vmovl_s16(input_s16_low);
6769     const int32x4_t val_high = vmovl_s16(input_s16_high);
6770 
6771     float32x4_t result_low, result_high;
6772     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6773                           &result_low);
6774     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6775                           &result_high);
6776 
6777     vst1q_f32(output_data + i, result_low);
6778     vst1q_f32(output_data + i + 4, result_high);
6779   }
6780 #endif  // NEON
6781   for (; i < flat_size; ++i) {
6782     const int32 val = input_data[i];
6783     const float result = static_cast<float>(scale * (val - zero_point));
6784     output_data[i] = result;
6785   }
6786 }
6787 
Dequantize(const tflite::DequantizationParams & op_params,const RuntimeShape & input_shape,const int8_t * input_data,const RuntimeShape & output_shape,float * output_data)6788 inline void Dequantize(const tflite::DequantizationParams& op_params,
6789                        const RuntimeShape& input_shape,
6790                        const int8_t* input_data,
6791                        const RuntimeShape& output_shape, float* output_data) {
6792   ruy::profiler::ScopeLabel label("Dequantize/Int8");
6793   const int32 zero_point = op_params.zero_point;
6794   const double scale = op_params.scale;
6795   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6796 
6797   int i = 0;
6798 #ifdef USE_NEON
6799   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6800   const float32x4_t zero_times_scale_dup =
6801       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6802   for (; i <= flat_size - 8; i += 8) {
6803     const int8x8_t input_s8 = vld1_s8(input_data + i);
6804     const int16x8_t input_s16 = vmovl_s8(input_s8);
6805     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6806     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6807     const int32x4_t val_low = vmovl_s16(input_s16_low);
6808     const int32x4_t val_high = vmovl_s16(input_s16_high);
6809 
6810     float32x4_t result_low, result_high;
6811     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6812                           &result_low);
6813     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6814                           &result_high);
6815 
6816     vst1q_f32(output_data + i, result_low);
6817     vst1q_f32(output_data + i + 4, result_high);
6818   }
6819 #endif  // NEON
6820   for (; i < flat_size; ++i) {
6821     const int32 val = input_data[i];
6822     const float result = static_cast<float>(scale * (val - zero_point));
6823     output_data[i] = result;
6824   }
6825 }
6826 
Dequantize(const tflite::DequantizationParams & op_params,const RuntimeShape & input_shape,const int16_t * input_data,const RuntimeShape & output_shape,float * output_data)6827 inline void Dequantize(const tflite::DequantizationParams& op_params,
6828                        const RuntimeShape& input_shape,
6829                        const int16_t* input_data,
6830                        const RuntimeShape& output_shape, float* output_data) {
6831   ruy::profiler::ScopeLabel label("Dequantize/Int16");
6832   const int32 zero_point = op_params.zero_point;
6833   const double scale = op_params.scale;
6834   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6835 
6836   int i = 0;
6837 #ifdef USE_NEON
6838   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6839   const float32x4_t zero_times_scale_dup =
6840       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6841   for (; i <= flat_size - 8; i += 8) {
6842     const int16x4_t input_s16_low = vld1_s16(input_data + i);
6843     const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6844     const int32x4_t val_low = vmovl_s16(input_s16_low);
6845     const int32x4_t val_high = vmovl_s16(input_s16_high);
6846 
6847     float32x4_t result_low, result_high;
6848     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6849                           &result_low);
6850     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6851                           &result_high);
6852 
6853     vst1q_f32(output_data + i, result_low);
6854     vst1q_f32(output_data + i + 4, result_high);
6855   }
6856 #endif  // NEON
6857   for (; i < flat_size; ++i) {
6858     const int32 val = input_data[i];
6859     const float result = static_cast<float>(scale * (val - zero_point));
6860     output_data[i] = result;
6861   }
6862 }
6863 
Dequantize(const RuntimeShape & input_shape,const Eigen::half * input_data,const RuntimeShape & output_shape,float * output_data)6864 inline void Dequantize(const RuntimeShape& input_shape,
6865                        const Eigen::half* input_data,
6866                        const RuntimeShape& output_shape, float* output_data) {
6867   reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6868 }
6869 
6870 template <typename T>
AffineQuantize(const tflite::QuantizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,T * output_data)6871 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6872                            const RuntimeShape& input_shape,
6873                            const float* input_data,
6874                            const RuntimeShape& output_shape, T* output_data) {
6875   reference_ops::AffineQuantize(op_params, input_shape, input_data,
6876                                 output_shape, output_data);
6877 }
6878 
6879 template <>
AffineQuantize(const tflite::QuantizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,int8_t * output_data)6880 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6881                            const RuntimeShape& input_shape,
6882                            const float* input_data,
6883                            const RuntimeShape& output_shape,
6884                            int8_t* output_data) {
6885   ruy::profiler::ScopeLabel label("Quantize/Int8");
6886   const int32 zero_point = op_params.zero_point;
6887   const double scale = static_cast<double>(op_params.scale);
6888   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6889   static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6890   static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6891 
6892   int i = 0;
6893 #ifdef USE_NEON
6894   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6895   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6896   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6897   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6898 
6899   for (; i <= flat_size - 8; i += 8) {
6900     const float* src_data_ptr = input_data + i;
6901     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6902     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6903 
6904     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6905     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6906 
6907     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6908     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6909 
6910     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6911     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6912 
6913     // Clamp the values to fit the target type's range.
6914     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6915     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6916     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6917     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6918 
6919     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6920     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6921     const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6922     const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
6923     vst1_s8(output_data + i, combined_val_narrowed);
6924   }
6925 #endif  // NEON
6926 
6927   for (; i < flat_size; ++i) {
6928     const float val = input_data[i];
6929     const int32 unclamped =
6930         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6931     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6932     output_data[i] = clamped;
6933   }
6934 }
6935 
6936 template <>
AffineQuantize(const tflite::QuantizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,uint8_t * output_data)6937 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6938                            const RuntimeShape& input_shape,
6939                            const float* input_data,
6940                            const RuntimeShape& output_shape,
6941                            uint8_t* output_data) {
6942   ruy::profiler::ScopeLabel label("Quantize/Uint8");
6943   const int32 zero_point = op_params.zero_point;
6944   const double scale = static_cast<double>(op_params.scale);
6945   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6946   static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6947   static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6948 
6949   int i = 0;
6950 #ifdef USE_NEON
6951   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6952   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6953   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6954   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6955 
6956   for (; i <= flat_size - 8; i += 8) {
6957     const float* src_data_ptr = input_data + i;
6958     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6959     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6960 
6961     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6962     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6963 
6964     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6965     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6966 
6967     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6968     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6969 
6970     // Clamp the values to fit the target type's range.
6971     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6972     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6973     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6974     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6975 
6976     const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6977     const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6978     const uint16x8_t combined_val =
6979         vcombine_u16(narrowed_val_0, narrowed_val_1);
6980     const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6981     vst1_u8(output_data + i, combined_val_narrowed);
6982   }
6983 #endif  // NEON
6984 
6985   for (; i < flat_size; ++i) {
6986     const float val = input_data[i];
6987     const int32 unclamped =
6988         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6989     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6990     output_data[i] = clamped;
6991   }
6992 }
6993 
6994 template <>
AffineQuantize(const tflite::QuantizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,int16_t * output_data)6995 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6996                            const RuntimeShape& input_shape,
6997                            const float* input_data,
6998                            const RuntimeShape& output_shape,
6999                            int16_t* output_data) {
7000   ruy::profiler::ScopeLabel label("Quantize/Int16");
7001   const int32 zero_point = op_params.zero_point;
7002   const double scale = static_cast<double>(op_params.scale);
7003   const int flat_size = MatchingFlatSize(input_shape, output_shape);
7004   static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
7005   static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
7006 
7007   int i = 0;
7008 #ifdef USE_NEON
7009   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
7010   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
7011   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
7012   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
7013 
7014   for (; i <= flat_size - 8; i += 8) {
7015     const float* src_data_ptr = input_data + i;
7016     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
7017     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
7018 
7019     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
7020     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
7021 
7022     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
7023     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
7024 
7025     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
7026     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
7027 
7028     // Clamp the values to fit the target type's range.
7029     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
7030     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
7031     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
7032     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
7033 
7034     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
7035     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
7036     vst1_s16(output_data + i, narrowed_val_0);
7037     vst1_s16(output_data + i + 4, narrowed_val_1);
7038   }
7039 #endif  // NEON
7040 
7041   for (; i < flat_size; ++i) {
7042     const float val = input_data[i];
7043     const int32 unclamped =
7044         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
7045     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
7046     output_data[i] = clamped;
7047   }
7048 }
7049 
7050 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7051 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7052 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7053 #ifdef GEMMLOWP_NEON
7054 
SaturatingRounding(int16x8_t input_val_0,int16x8_t input_val_1,int16x8_t input_val_2,int16x8_t input_val_3,int input_left_shift,int input_multiplier)7055 inline int16x8x4_t SaturatingRounding(
7056     int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
7057     int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
7058   // This performs what is expressed in the scalar code as
7059   // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7060   //      static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7061   //      static_cast<int16>(input_multiplier));
7062   const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
7063   const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
7064   const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
7065   const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
7066   const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
7067   int16x8x4_t result;
7068   result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
7069   result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
7070   result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
7071   result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
7072   return result;
7073 }
7074 
// 4 integer bits are enough for logistic since logistic(16) equals one to
// about 7 decimal digits.
FixedPoint4Logistic(int16x8x4_t input_val)7077 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
7078   // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
7079   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
7080   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
7081   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
7082   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
7083   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
7084   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
7085 
7086   // TODO(b/134622898) Implement a low accuracy version of logistic. In this
7087   // method, gemmlowp::tanh spends about 80% of the execution times. The
7088   // current implementation is rougly 12-bit accurate in the 16-bit fixed
7089   // point case. Until reaching to error bounds, there are rooms for
7090   // improvements.
7091   const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
7092   const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
7093   const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
7094   const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
7095 
7096   // Divide by 2^7 as in the scalar code
7097   int16x8x4_t result;
7098   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
7099   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
7100   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
7101   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
7102   return result;
7103 }
7104 
// 4 integer bits are enough for tanh since tanh(16) equals one to at least
// 11 decimal digits.
FixedPoint4Tanh(int16x8x4_t input_val)7107 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
7108   // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
7109   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
7110   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
7111   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
7112   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
7113   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
7114   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
7115 
7116   // TODO(b/134622898) Implement a low accuracy version of logistic. In this
7117   // method, gemmlowp::tanh spends about 80% of the execution times. The
7118   // current implementation is rougly 12-bit accurate in the 16-bit fixed
7119   // point case. Until reaching to error bounds, there are rooms for
7120   // improvements.
7121   const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
7122   const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
7123   const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
7124   const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
7125 
7126   // Divide by 2^7 as in the scalar code
7127   int16x8x4_t result;
7128   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
7129   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
7130   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
7131   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
7132   return result;
7133 }
7134 
CalculateUnsignedClampingWithRangeBitMasks(int16x8x2_t input_val,int16x8_t range_radius_dup,int16x8_t neg_range_radius_dup)7135 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
7136     int16x8x2_t input_val, int16x8_t range_radius_dup,
7137     int16x8_t neg_range_radius_dup) {
7138   const uint16x8_t mask_rightclamp_0 =
7139       vcgtq_s16(input_val.val[0], range_radius_dup);
7140   const uint16x8_t mask_rightclamp_1 =
7141       vcgtq_s16(input_val.val[1], range_radius_dup);
7142 
7143   const uint16x8_t mask_leftclamp_0 =
7144       vcgeq_s16(input_val.val[0], neg_range_radius_dup);
7145   const uint16x8_t mask_leftclamp_1 =
7146       vcgeq_s16(input_val.val[1], neg_range_radius_dup);
7147 
7148   uint8x16x2_t result;
7149   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
7150                               vshrn_n_u16(mask_leftclamp_1, 8));
7151   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
7152                               vshrn_n_u16(mask_rightclamp_1, 8));
7153   return result;
7154 }
7155 
CalculateSignedClampingWithRangeBitMasks(int16x8x2_t input_val,int16x8_t range_radius_dup,int16x8_t neg_range_radius_dup)7156 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
7157     int16x8x2_t input_val, int16x8_t range_radius_dup,
7158     int16x8_t neg_range_radius_dup) {
7159   const uint16x8_t mask_rightclamp_0 =
7160       vcgtq_s16(input_val.val[0], range_radius_dup);
7161   const uint16x8_t mask_rightclamp_1 =
7162       vcgtq_s16(input_val.val[1], range_radius_dup);
7163 
7164   const uint16x8_t mask_leftclamp_0 =
7165       vcltq_s16(input_val.val[0], neg_range_radius_dup);
7166   const uint16x8_t mask_leftclamp_1 =
7167       vcltq_s16(input_val.val[1], neg_range_radius_dup);
7168 
7169   uint8x16x2_t result;
7170   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
7171                               vshrn_n_u16(mask_leftclamp_1, 8));
7172   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
7173                               vshrn_n_u16(mask_rightclamp_1, 8));
7174   return result;
7175 }
7176 
ClampWithRangeAndStore(uint8_t * output_dst,uint8x16_t input_val,uint8x16x2_t masks_clamp)7177 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
7178                                    uint8x16x2_t masks_clamp) {
7179   // Store back to memory
7180   vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
7181                                 masks_clamp.val[0]));
7182 }
7183 
ClampWithRangeAndStore(int8_t * output_dst,int8x16_t input_val,uint8x16x2_t masks_clamp)7184 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
7185                                    uint8x16x2_t masks_clamp) {
7186   static const int8x16_t max_dup = vdupq_n_s8(127);
7187   static const int8x16_t min_dup = vdupq_n_s8(-128);
7188   // Store back to memory
7189   vst1q_s8(output_dst,
7190            vbslq_s8(masks_clamp.val[1], max_dup,
7191                     vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
7192 }
7193 
7194 #endif  // GEMMLOWP_NEON
7195 
Tanh16bitPrecision(const TanhParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)7196 inline void Tanh16bitPrecision(const TanhParams& params,
7197                                const RuntimeShape& input_shape,
7198                                const uint8* input_data,
7199                                const RuntimeShape& output_shape,
7200                                uint8* output_data) {
7201   // Note that this is almost the exact same code as in Logistic().
7202   ruy::profiler::ScopeLabel label("Tanh/Uint8");
7203   const int32 input_zero_point = params.input_zero_point;
7204   const int32 input_range_radius = params.input_range_radius;
7205   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
7206   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7207   const int size = MatchingFlatSize(input_shape, output_shape);
7208 
7209   int c = 0;
7210   int16_t output_zero_point = 128;
7211 
7212 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7213 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7214 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7215 #ifdef GEMMLOWP_NEON
7216   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7217   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7218   const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
7219 
7220   // Handle 32 values at a time
7221   for (; c <= size - 32; c += 32) {
7222     // Read input uint8 values, cast to int16 and subtract input_zero_point
7223     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7224     const int16x8x2_t input_val_centered_0_1 =
7225         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7226     const int16x8x2_t input_val_centered_2_3 =
7227         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7228 
7229     // Prepare the bit masks that we will use at the end to implement the logic
7230     // that was expressed in the scalar code with branching:
7231     //   if (input_val_centered < -input_range_radius) {
7232     //     output_val = 0;
7233     //   } else if (input_val_centered > input_range_radius) {
7234     //     output_val = 255;
7235     //   } else {
7236     //     ...
7237     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
7238         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7239     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
7240         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7241 
7242     int16x8x4_t input_val_rescaled = SaturatingRounding(
7243         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7244         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7245         input_left_shift, input_multiplier);
7246 
7247     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
7248 
7249     // Add the output zero point
7250     output_val_s16.val[0] =
7251         vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
7252     output_val_s16.val[1] =
7253         vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
7254     output_val_s16.val[2] =
7255         vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
7256     output_val_s16.val[3] =
7257         vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
7258 
7259     // Cast output values to uint8, saturating
7260     uint8x16_t output_val_u8_0_1 = vcombine_u8(
7261         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7262     uint8x16_t output_val_u8_2_3 = vcombine_u8(
7263         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7264 
7265     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7266     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7267                            masks_clamp_2_3);
7268   }
7269 #endif  // GEMMLOWP_NEON
7270   // Leftover loop: handle one value at a time with scalar code.
7271   for (; c < size; ++c) {
7272     const uint8 input_val_u8 = input_data[c];
7273     const int16 input_val_centered =
7274         static_cast<int16>(input_val_u8) - input_zero_point;
7275     uint8 output_val;
7276     if (input_val_centered < -input_range_radius) {
7277       output_val = 0;
7278     } else if (input_val_centered > input_range_radius) {
7279       output_val = 255;
7280     } else {
7281       using gemmlowp::SaturatingRoundingDoublingHighMul;
7282       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7283           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7284           static_cast<int16>(input_multiplier));
7285       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7286       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7287       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7288       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
7289       using gemmlowp::RoundingDivideByPOT;
7290       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
7291       output_val_s16 += output_zero_point;
7292       if (output_val_s16 == 256) {
7293         output_val_s16 = 255;
7294       }
7295       TFLITE_DCHECK_GE(output_val_s16, 0);
7296       TFLITE_DCHECK_LE(output_val_s16, 255);
7297       output_val = static_cast<uint8>(output_val_s16);
7298     }
7299     output_data[c] = output_val;
7300   }
7301 }
7302 
Tanh16bitPrecision(const TanhParams & params,const RuntimeShape & input_shape,const int8 * input_data,const RuntimeShape & output_shape,int8 * output_data)7303 inline void Tanh16bitPrecision(const TanhParams& params,
7304                                const RuntimeShape& input_shape,
7305                                const int8* input_data,
7306                                const RuntimeShape& output_shape,
7307                                int8* output_data) {
7308   // Note that this is almost the exact same code as in Logistic().
7309   ruy::profiler::ScopeLabel label("Tanh/Int8");
7310   const int32 input_zero_point = params.input_zero_point;
7311   const int32 input_range_radius = params.input_range_radius;
7312   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
7313   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7314   const int size = MatchingFlatSize(input_shape, output_shape);
7315 
7316   int c = 0;
7317 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7318 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7319 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7320 #ifdef GEMMLOWP_NEON
7321   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7322   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7323 
7324   // Handle 32 values at a time
7325   for (; c <= size - 32; c += 32) {
7326     // Read input int8 values, cast to int16 and subtract input_zero_point
7327     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7328     const int16x8x2_t input_val_centered_0_1 =
7329         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7330     const int16x8x2_t input_val_centered_2_3 =
7331         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7332 
7333     // Prepare the bit masks that we will use at the end to implement the logic
7334     // that was expressed in the scalar code with branching:
7335     //   if (input_val_centered < -input_range_radius) {
7336     //     output_val = -128;
7337     //   } else if (input_val_centered > input_range_radius) {
7338     //     output_val = 127;
7339     //   } else {
7340     //     ...
7341     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7342         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7343     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7344         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7345 
7346     int16x8x4_t input_val_rescaled = SaturatingRounding(
7347         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7348         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7349         input_left_shift, input_multiplier);
7350 
7351     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
7352 
7353     // Cast output values to uint8, saturating
7354     int8x16_t output_val_s8_0_1 = vcombine_s8(
7355         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7356     int8x16_t output_val_s8_2_3 = vcombine_s8(
7357         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7358 
7359     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7360     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7361                            masks_clamp_2_3);
7362   }
7363 #endif  // GEMMLOWP_NEON
7364   // Leftover loop: handle one value at a time with scalar code.
7365   for (; c < size; ++c) {
7366     const int8 input_val_s8 = input_data[c];
7367     const int16 input_val_centered =
7368         static_cast<int16>(input_val_s8) - input_zero_point;
7369     int8 output_val;
7370     if (input_val_centered <= -input_range_radius) {
7371       output_val = -128;
7372     } else if (input_val_centered >= input_range_radius) {
7373       output_val = 127;
7374     } else {
7375       using gemmlowp::SaturatingRoundingDoublingHighMul;
7376       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7377           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7378           static_cast<int16>(input_multiplier));
7379       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7380       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7381       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7382       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
7383       using gemmlowp::RoundingDivideByPOT;
7384       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
7385       if (output_val_s16 == 128) {
7386         output_val_s16 = 127;
7387       }
7388       TFLITE_DCHECK_GE(output_val_s16, -128);
7389       TFLITE_DCHECK_LE(output_val_s16, 127);
7390       output_val = static_cast<int8>(output_val_s16);
7391     }
7392     output_data[c] = output_val;
7393   }
7394 }
7395 
Logistic16bitPrecision(const LogisticParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)7396 inline void Logistic16bitPrecision(const LogisticParams& params,
7397                                    const RuntimeShape& input_shape,
7398                                    const uint8* input_data,
7399                                    const RuntimeShape& output_shape,
7400                                    uint8* output_data) {
7401   ruy::profiler::ScopeLabel label("Logistic/Uint8");
7402   const int32 input_zero_point = params.input_zero_point;
7403   const int32 input_range_radius = params.input_range_radius;
7404   const int32 input_multiplier = params.input_multiplier;
7405   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7406   const int size = MatchingFlatSize(input_shape, output_shape);
7407 
7408   int c = 0;
7409 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7410 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7411 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7412 #ifdef GEMMLOWP_NEON
7413   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7414   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7415 
7416   // Handle 32 values at a time
7417   for (; c <= size - 32; c += 32) {
7418     // Read input uint8 values, cast to int16 and subtract input_zero_point
7419     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7420     const int16x8x2_t input_val_centered_0_1 =
7421         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7422     const int16x8x2_t input_val_centered_2_3 =
7423         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7424 
7425     // Prepare the bit masks that we will use at the end to implement the logic
7426     // that was expressed in the scalar code with branching:
7427     //   if (input_val_centered < -input_range_radius) {
7428     //     output_val = 0;
7429     //   } else if (input_val_centered > input_range_radius) {
7430     //     output_val = 255;
7431     //   } else {
7432     //     ...
7433     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
7434         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7435     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
7436         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7437 
7438     int16x8x4_t input_val_rescaled = SaturatingRounding(
7439         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7440         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7441         input_left_shift, input_multiplier);
7442 
7443     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7444 
7445     // Cast output values to uint8, saturating
7446     uint8x16_t output_val_u8_0_1 = vcombine_u8(
7447         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7448     uint8x16_t output_val_u8_2_3 = vcombine_u8(
7449         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7450 
7451     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7452     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7453                            masks_clamp_2_3);
7454   }
7455 #endif  // GEMMLOWP_NEON
7456   // Leftover loop: handle one value at a time with scalar code.
7457   for (; c < size; ++c) {
7458     const uint8 input_val_u8 = input_data[c];
7459     const int16 input_val_centered =
7460         static_cast<int16>(input_val_u8) - input_zero_point;
7461     uint8 output_val;
7462     if (input_val_centered < -input_range_radius) {
7463       output_val = 0;
7464     } else if (input_val_centered > input_range_radius) {
7465       output_val = 255;
7466     } else {
7467       using gemmlowp::SaturatingRoundingDoublingHighMul;
7468       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7469           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7470           static_cast<int16>(input_multiplier));
7471       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7472       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7473       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7474       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7475       using gemmlowp::RoundingDivideByPOT;
7476       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7477       if (output_val_s16 == 256) {
7478         output_val_s16 = 255;
7479       }
7480       TFLITE_DCHECK_GE(output_val_s16, 0);
7481       TFLITE_DCHECK_LE(output_val_s16, 255);
7482       output_val = static_cast<uint8>(output_val_s16);
7483     }
7484     output_data[c] = output_val;
7485   }
7486 }
7487 
Logistic16bitPrecision(const LogisticParams & params,const RuntimeShape & input_shape,const int8 * input_data,const RuntimeShape & output_shape,int8 * output_data)7488 inline void Logistic16bitPrecision(const LogisticParams& params,
7489                                    const RuntimeShape& input_shape,
7490                                    const int8* input_data,
7491                                    const RuntimeShape& output_shape,
7492                                    int8* output_data) {
7493   ruy::profiler::ScopeLabel label("Logistic/Int8");
7494   const int32 input_zero_point = params.input_zero_point;
7495   const int32 input_range_radius = params.input_range_radius;
7496   const int32 input_multiplier = params.input_multiplier;
7497   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7498   const int size = MatchingFlatSize(input_shape, output_shape);
7499 
7500   int c = 0;
7501   const int16 output_zero_point = 128;
7502 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7503 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7504 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7505 #ifdef GEMMLOWP_NEON
7506   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7507   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7508   const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7509 
7510   // Handle 32 values at a time
7511   for (; c <= size - 32; c += 32) {
7512     // Read input int8 values, cast to int16 and subtract input_zero_point
7513     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7514     const int16x8x2_t input_val_centered_0_1 =
7515         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7516     const int16x8x2_t input_val_centered_2_3 =
7517         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7518 
7519     // Prepare the bit masks that we will use at the end to implement the logic
7520     // that was expressed in the scalar code with branching:
7521     //   if (input_val_centered < -input_range_radius) {
7522     //     output_val = -128;
7523     //   } else if (input_val_centered > input_range_radius) {
7524     //     output_val = 127;
7525     //   } else {
7526     //     ...
7527     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7528         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7529     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7530         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7531 
7532     int16x8x4_t input_val_rescaled = SaturatingRounding(
7533         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7534         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7535         input_left_shift, input_multiplier);
7536 
7537     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7538 
7539     // Substract output zero point.
7540     output_val_s16.val[0] =
7541         vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7542     output_val_s16.val[1] =
7543         vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7544     output_val_s16.val[2] =
7545         vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7546     output_val_s16.val[3] =
7547         vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7548 
7549     // Cast output values to int8, saturating
7550     int8x16_t output_val_s8_0_1 = vcombine_s8(
7551         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7552     int8x16_t output_val_s8_2_3 = vcombine_s8(
7553         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7554 
7555     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7556     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7557                            masks_clamp_2_3);
7558   }
7559 #endif  // GEMMLOWP_NEON
7560   // Leftover loop: handle one value at a time with scalar code.
7561   for (; c < size; ++c) {
7562     const int8 input_val_s8 = input_data[c];
7563     const int16 input_val_centered =
7564         static_cast<int16>(input_val_s8) - input_zero_point;
7565     int8 output_val;
7566     if (input_val_centered < -input_range_radius) {
7567       output_val = -128;
7568     } else if (input_val_centered > input_range_radius) {
7569       output_val = 127;
7570     } else {
7571       using gemmlowp::SaturatingRoundingDoublingHighMul;
7572       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7573           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7574           static_cast<int16>(input_multiplier));
7575       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7576       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7577       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7578       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7579       using gemmlowp::RoundingDivideByPOT;
7580       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7581       output_val_s16 -= output_zero_point;
7582       if (output_val_s16 == 128) {
7583         output_val_s16 = 127;
7584       }
7585       TFLITE_DCHECK_GE(output_val_s16, -128);
7586       TFLITE_DCHECK_LE(output_val_s16, 127);
7587       output_val = static_cast<int8>(output_val_s16);
7588     }
7589     output_data[c] = output_val;
7590   }
7591 }
7592 
// Transpose2D only deals with typical 2D matrix transpose ops.
// Transposes the row-major d0 x d1 matrix in input_data into the row-major
// d1 x d0 matrix in output_data, i.e. output[j * d0 + i] = input[i * d1 + j].
//
// The input is processed in bands of 4 rows, from top to bottom. Within a
// band, 4x4 blocks are transposed from left to right across the columns.
// Trailing columns that do not fill a whole block are copied element by
// element. The final d0 % 4 rows are copied by a scalar loop at the end.
template <typename T>
inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
                        const RuntimeShape& output_shape, T* output_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);

  const int d0 = input_shape.DimsData()[0];
  const int d1 = input_shape.DimsData()[1];
  // Rows per band / block edge length.
  const int kLines = 4;
  // Distance from the end of a band's first row to the start of the next band.
  const int kSkipSize = (kLines - 1) * d1;

  // Points at the next unprocessed element of the current band's first row.
  const T* input = input_data;

  int i = 0;
  for (; i <= d0 - kLines; i += kLines) {
    // Output column i (rows i..i+3 of the input become columns i..i+3).
    T* output = output_data + i;

    // Touch the start of each of the band's 4 rows via
    // optimized_ops_preload_l1_keep (presumably a cache prefetch hint; it is
    // defined elsewhere).
    const T* input_ptr = input;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);

    int j = 0;
    for (; j <= d1 - kLines; j += kLines) {
      // Load the 4x4 block whose top-left element is input[0] (aRC = row R,
      // column C within the block).
      input_ptr = input;
      const T a00 = input_ptr[0];
      const T a01 = input_ptr[1];
      const T a02 = input_ptr[2];
      const T a03 = input_ptr[3];
      input_ptr += d1;
      const T a10 = input_ptr[0];
      const T a11 = input_ptr[1];
      const T a12 = input_ptr[2];
      const T a13 = input_ptr[3];
      input_ptr += d1;
      const T a20 = input_ptr[0];
      const T a21 = input_ptr[1];
      const T a22 = input_ptr[2];
      const T a23 = input_ptr[3];
      input_ptr += d1;
      const T a30 = input_ptr[0];
      const T a31 = input_ptr[1];
      const T a32 = input_ptr[2];
      const T a33 = input_ptr[3];

      // Write each block column as 4 consecutive elements of one output row.
      output[0] = a00;
      output[1] = a10;
      output[2] = a20;
      output[3] = a30;
      output += d0;

      output[0] = a01;
      output[1] = a11;
      output[2] = a21;
      output[3] = a31;
      output += d0;

      output[0] = a02;
      output[1] = a12;
      output[2] = a22;
      output[3] = a32;
      output += d0;

      output[0] = a03;
      output[1] = a13;
      output[2] = a23;
      output[3] = a33;
      output += d0;

      input += kLines;
    }
    if (j == d1) {
      // The band's columns were consumed exactly; jump to the next band.
      input += kSkipSize;
    } else {
      // Copy the remaining d1 - j columns of this band element by element.
      for (int p = 0; p < kLines; ++p) {
        for (int q = 0; q < d1 - j; ++q) {
          *(output + q * d0 + p) = *(input + p * d1 + q);
        }
      }
      input += (d1 - j) + kSkipSize;
    }
  }
  // Remaining (fewer than 4) rows: scatter each input row into an output
  // column.
  for (; i < d0; ++i) {
    T* output = output_data + i;
    for (int j = 0; j < d1; ++j) {
      *output = *input;
      output += d0;
      ++input;
    }
  }
}
7691 
// int32_t specialization of Transpose2D: transposes the row-major d0 x d1
// matrix in input_data into the row-major d1 x d0 matrix in output_data.
// With USE_NEON, 4x4 blocks are transposed in registers using unzip/transpose
// shuffles, with the same band and leftover structure as the generic version.
// Without USE_NEON, the whole matrix goes through the scalar loop at the end.
template <>
inline void Transpose2D(const RuntimeShape& input_shape,
                        const int32_t* input_data,
                        const RuntimeShape& output_shape,
                        int32_t* output_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);

  const int d0 = input_shape.DimsData()[0];
  const int d1 = input_shape.DimsData()[1];
#ifdef USE_NEON
  // Rows per band / block edge length.
  const int kLines = 4;
  // Distance from the end of a band's first row to the start of the next band.
  const int kSkipSize = (kLines - 1) * d1;
#endif

  // Points at the next unprocessed element of the current band's first row.
  const int32_t* input = input_data;

  int i = 0;
#ifdef USE_NEON
  for (; i <= d0 - kLines; i += kLines) {
    int32_t* output = output_data + i;

    // Touch the start of each of the band's 4 rows (presumably a prefetch
    // hint; optimized_ops_preload_l1_keep is defined elsewhere).
    const int32_t* input_ptr = input;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);
    input_ptr += d1;
    optimized_ops_preload_l1_keep(input_ptr);

    int j = 0;
    for (; j <= d1 - kLines; j += kLines) {
      // Load the 4 rows of the current 4x4 block (a0 reads via input, which
      // equals input_ptr here).
      input_ptr = input;
      int32x4_t a0 = vld1q_s32(input);
      input_ptr += d1;
      int32x4_t a1 = vld1q_s32(input_ptr);
      input_ptr += d1;
      int32x4_t a2 = vld1q_s32(input_ptr);
      input_ptr += d1;
      int32x4_t a3 = vld1q_s32(input_ptr);

      // In-register 4x4 transpose:
      //   tmp1 = even/odd columns of rows 0 and 2,
      //   tmp2 = even/odd columns of rows 1 and 3.
      // The vtrnq steps then interleave them so that tmp3.val[0],
      // tmp4.val[0], tmp3.val[1], tmp4.val[1] hold input columns 0, 1, 2, 3.
      int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
      int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
      int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
      int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);

      vst1q_s32(output, tmp3.val[0]);
      output += d0;
      vst1q_s32(output, tmp4.val[0]);
      output += d0;
      vst1q_s32(output, tmp3.val[1]);
      output += d0;
      vst1q_s32(output, tmp4.val[1]);
      output += d0;
      input += kLines;
    }
    if (j == d1) {
      // The band's columns were consumed exactly; jump to the next band.
      input += kSkipSize;
    } else {
      // Copy the remaining d1 - j columns of this band element by element.
      for (int p = 0; p < kLines; ++p) {
        for (int q = 0; q < d1 - j; ++q) {
          *(output + q * d0 + p) = *(input + p * d1 + q);
        }
      }
      input += (d1 - j) + kSkipSize;
    }
  }
#endif
  // Remaining rows (all rows without NEON): scatter each input row into an
  // output column.
  for (; i < d0; ++i) {
    int32_t* output = output_data + i;
    for (int j = 0; j < d1; ++j) {
      *output = *input;
      output += d0;
      ++input;
    }
  }
}
7770 
7771 // TODO(b/173718660): see if we can reduce the number
7772 // of lines of code in branching without affecting latency.
7773 template <typename T>
Transpose3D(const TransposeParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)7774 inline void Transpose3D(const TransposeParams& params,
7775                         const RuntimeShape& input_shape, const T* input_data,
7776                         const RuntimeShape& output_shape, T* output_data) {
7777   int s1, s2, s3;
7778   s1 = input_shape.Dims(0);
7779   s2 = input_shape.Dims(1);
7780   s3 = input_shape.Dims(2);
7781 
7782   int p1, p2, p3;
7783   if (params.perm[0] == 2) {
7784     p1 = 1;
7785   } else if (params.perm[1] == 2) {
7786     p2 = 1;
7787   } else {
7788     p3 = 1;
7789   }
7790 
7791   if (params.perm[0] == 1) {
7792     p1 = s3;
7793   } else if (params.perm[1] == 1) {
7794     p2 = s3;
7795   } else {
7796     p3 = s3;
7797   }
7798 
7799   if (params.perm[0] == 0) {
7800     p1 = s2 * s3;
7801   } else if (params.perm[1] == 0) {
7802     p2 = s2 * s3;
7803   } else {
7804     p3 = s2 * s3;
7805   }
7806 
7807   int o_s[3];
7808   o_s[0] = input_shape.Dims(params.perm[0]);
7809   o_s[1] = input_shape.Dims(params.perm[1]);
7810   o_s[2] = input_shape.Dims(params.perm[2]);
7811 
7812   for (int i1 = 0; i1 < o_s[0]; ++i1) {
7813     for (int i2 = 0; i2 < o_s[1]; ++i2) {
7814       for (int i3 = 0; i3 < o_s[2]; ++i3) {
7815         const int i = i1 * p1 + i2 * p2 + i3 * p3;
7816         const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7817         output_data[o] = input_data[i];
7818       }
7819     }
7820   }
7821 }
7822 
7823 template <typename T, int N>
TransposeImpl(const TransposeParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)7824 void TransposeImpl(const TransposeParams& params,
7825                    const RuntimeShape& input_shape, const T* input_data,
7826                    const RuntimeShape& output_shape, T* output_data) {
7827   const int dims_cnt = input_shape.DimensionsCount();
7828 
7829   int dim0, dim1;
7830   if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7831                                                &dim1)) {
7832     Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7833                 RuntimeShape({dim1, dim0}), output_data);
7834     return;
7835   }
7836 
7837   // TODO(b/141217325): notably Eigen is better suited for
7838   // larger inputs whereas Transpose3D is generally
7839   // better for smaller ones.
7840   //
7841   // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7842   // and Transpose3D is better for 72^3 and down.
7843   //
7844   // 96^3 is not mobile-friendly for certain usecases
7845   // (e.g. model used in beam search for seq2seq) but is in others.
7846   // Consider tradeoffs.
7847   if (dims_cnt == 3) {
7848     Transpose3D(params, input_shape, input_data, output_shape, output_data);
7849     return;
7850   }
7851 
7852   // Reroute to the reference version if an optimized method for the given data
7853   // is not available.
7854   reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7855                                  output_data);
7856 }
7857 
7858 template <typename T, int N = 5>
Transpose(const TransposeParams & unshrinked_params,const RuntimeShape & unshrinked_input_shape,const T * input_data,const RuntimeShape & unshrinked_output_shape,T * output_data)7859 void Transpose(const TransposeParams& unshrinked_params,
7860                const RuntimeShape& unshrinked_input_shape, const T* input_data,
7861                const RuntimeShape& unshrinked_output_shape, T* output_data) {
7862   ruy::profiler::ScopeLabel label("Transpose");
7863 
7864   const int output_size = unshrinked_output_shape.DimensionsCount();
7865   TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7866   TFLITE_DCHECK_LE(output_size, N);
7867   TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7868 
7869   RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7870   RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7871   TransposeParams shrinked_params = unshrinked_params;
7872 
7873   // Reduce any dimensions that have one size. Lower transpose op usually
7874   // performs better since memory access patterns will be improved.
7875   transpose_utils::RemoveOneSizeDimensions(
7876       &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7877 
7878   // Handle identity cases.
7879   // TODO(b/140779653): Add an optimization pass in the conversion process to
7880   // remove transpose op nodes where they do nothing like the below one.
7881   bool identical = true;
7882   for (int i = 0; i < shrinked_params.perm_count; ++i) {
7883     if (shrinked_params.perm[i] != i) {
7884       identical = false;
7885       break;
7886     }
7887   }
7888   if (identical) {
7889     memcpy(output_data, input_data,
7890            unshrinked_input_shape.FlatSize() * sizeof(T));
7891     return;
7892   }
7893 
7894   // Reduce dimensions by flattening.
7895   if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7896     RuntimeShape non_flatten_input_shape;
7897     RuntimeShape non_flatten_output_shape;
7898     TransposeParams non_flatten_params;
7899     const int total_size = shrinked_input_shape.FlatSize();
7900     const int non_flatten_size = transpose_utils::Flatten(
7901         shrinked_input_shape, shrinked_output_shape, shrinked_params,
7902         &non_flatten_input_shape, &non_flatten_output_shape,
7903         &non_flatten_params);
7904     TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7905 
7906     for (int i = 0; i < total_size; i += non_flatten_size) {
7907       TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7908                           input_data + i, non_flatten_output_shape,
7909                           output_data + i);
7910     }
7911     return;
7912   }
7913 
7914   // Call non-flattened case.
7915   TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7916                       shrinked_output_shape, output_data);
7917 }
7918 
7919 // Assume input1 & input2 have the same scale & zero point.
MaximumElementwise(int size,const ArithmeticParams & params,const int8 * input1_data,const int8 * input2_data,int8 * output_data)7920 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7921                                const int8* input1_data, const int8* input2_data,
7922                                int8* output_data) {
7923   ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7924   int i = 0;
7925 #ifdef USE_NEON
7926   for (; i <= size - 16; i += 16) {
7927     const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7928     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7929     const int8x16_t max_data =
7930         vmaxq_s8(input1_val_original, input2_val_original);
7931     vst1q_s8(output_data + i, max_data);
7932   }
7933 #endif  // USE_NEON
7934   for (; i < size; ++i) {
7935     const int8 input1_val = input1_data[i];
7936     const int8 input2_val = input2_data[i];
7937     output_data[i] = std::max(input1_val, input2_val);
7938   }
7939 }
7940 
MaximumScalarBroadcast(int size,const ArithmeticParams & params,int8 input1_data,const int8 * input2_data,int8 * output_data)7941 inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
7942                                    int8 input1_data, const int8* input2_data,
7943                                    int8* output_data) {
7944   ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
7945   int i = 0;
7946 
7947 #ifdef USE_NEON
7948   const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7949   for (; i <= size - 16; i += 16) {
7950     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7951     const int8x16_t max_data =
7952         vmaxq_s8(input1_val_original, input2_val_original);
7953     vst1q_s8(output_data + i, max_data);
7954   }
7955 #endif  // USE_NEON
7956   for (; i < size; ++i) {
7957     const int8 input2_val = input2_data[i];
7958     output_data[i] = std::max(input1_data, input2_val);
7959   }
7960 }
7961 
7962 // Assume input1 & input2 have the same scale & zero point.
MinimumElementwise(int size,const ArithmeticParams & params,const int8 * input1_data,const int8 * input2_data,int8 * output_data)7963 inline void MinimumElementwise(int size, const ArithmeticParams& params,
7964                                const int8* input1_data, const int8* input2_data,
7965                                int8* output_data) {
7966   ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
7967   int i = 0;
7968 #ifdef USE_NEON
7969   for (; i <= size - 16; i += 16) {
7970     const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7971     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7972     const int8x16_t min_data =
7973         vminq_s8(input1_val_original, input2_val_original);
7974     vst1q_s8(output_data + i, min_data);
7975   }
7976 #endif  // USE_NEON
7977   for (; i < size; ++i) {
7978     const int8 input1_val = input1_data[i];
7979     const int8 input2_val = input2_data[i];
7980     output_data[i] = std::min(input1_val, input2_val);
7981   }
7982 }
7983 
MinimumScalarBroadcast(int size,const ArithmeticParams & params,int8 input1_data,const int8 * input2_data,int8 * output_data)7984 inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
7985                                    int8 input1_data, const int8* input2_data,
7986                                    int8* output_data) {
7987   ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
7988   int i = 0;
7989 
7990 #ifdef USE_NEON
7991   const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7992   for (; i <= size - 16; i += 16) {
7993     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7994     const int8x16_t min_data =
7995         vminq_s8(input1_val_original, input2_val_original);
7996     vst1q_s8(output_data + i, min_data);
7997   }
7998 #endif  // USE_NEON
7999   for (; i < size; ++i) {
8000     const int8 input2_val = input2_data[i];
8001     output_data[i] = std::min(input1_data, input2_val);
8002   }
8003 }
8004 
8005 template <typename Op>
BroadcastMaximumDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data,Op op)8006 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
8007                                      const RuntimeShape& input1_shape,
8008                                      const int8* input1_data,
8009                                      const RuntimeShape& input2_shape,
8010                                      const int8* input2_data,
8011                                      const RuntimeShape& output_shape,
8012                                      int8* output_data, Op op) {
8013   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8014     return reference_ops::MaximumMinimumBroadcastSlow(
8015         input1_shape, input1_data, input2_shape, input2_data, output_shape,
8016         output_data, op);
8017   }
8018 
8019   BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
8020                           input2_data, output_shape, output_data,
8021                           MaximumElementwise, MaximumScalarBroadcast);
8022 }
8023 
8024 template <typename Op>
BroadcastMinimumDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data,Op op)8025 inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
8026                                      const RuntimeShape& input1_shape,
8027                                      const int8* input1_data,
8028                                      const RuntimeShape& input2_shape,
8029                                      const int8* input2_data,
8030                                      const RuntimeShape& output_shape,
8031                                      int8* output_data, Op op) {
8032   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8033     return reference_ops::MaximumMinimumBroadcastSlow(
8034         input1_shape, input1_data, input2_shape, input2_data, output_shape,
8035         output_data, op);
8036   }
8037 
8038   BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
8039                           input2_data, output_shape, output_data,
8040                           MinimumElementwise, MinimumScalarBroadcast);
8041 }
8042 
8043 template <typename T>
CumsumImpl(const T * input_data,const RuntimeShape & shape,int axis,bool exclusive,bool reverse,T * output_data)8044 void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
8045                 bool exclusive, bool reverse, T* output_data) {
8046   Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
8047 
8048   for (int i = 0; i < axis; ++i) {
8049     dims[0] *= shape.Dims(i);
8050   }
8051   dims[1] = shape.Dims(axis);
8052   for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
8053     dims[2] *= shape.Dims(i);
8054   }
8055 
8056   typedef Eigen::TensorMap<
8057       Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
8058       Eigen::Aligned>
8059       ConstTensor;
8060   typedef Eigen::TensorMap<
8061       Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
8062       Tensor;
8063   ConstTensor input(input_data, dims);
8064   Tensor output(output_data, dims);
8065 
8066   if (reverse) {
8067     Eigen::array<bool, 3> reverse_idx = {false, true, false};
8068     output =
8069         input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
8070   } else {
8071     output = input.cumsum(1, exclusive);
8072   }
8073 }
8074 
8075 template <typename T>
CumSum(const T * input_data,const RuntimeShape & shape,int axis,bool exclusive,bool reverse,T * output_data)8076 void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
8077             bool exclusive, bool reverse, T* output_data) {
8078   const int dim = shape.DimensionsCount();
8079   TFLITE_DCHECK_GE(dim, 1);
8080   CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
8081 }
8082 
PReluScalarBroadcast(int size,const ArithmeticParams & params,float alpha,const float * input_data,float * output_data)8083 inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
8084                                  float alpha, const float* input_data,
8085                                  float* output_data) {
8086   ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
8087   int i = 0;
8088 
8089 #ifdef USE_NEON
8090   const float32x4_t zero_dup = vdupq_n_f32(0.0f);
8091   const float32x4_t alpha_dup = vdupq_n_f32(alpha);
8092   for (; i <= size - 16; i += 16) {
8093     const float32x4_t input1 = vld1q_f32(input_data + i);
8094     const float32x4_t input2 = vld1q_f32(input_data + i + 4);
8095     const float32x4_t input3 = vld1q_f32(input_data + i + 8);
8096     const float32x4_t input4 = vld1q_f32(input_data + i + 12);
8097 
8098     const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
8099     const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
8100     const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
8101     const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
8102 
8103     const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
8104     const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
8105     const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
8106     const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
8107 
8108     const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
8109     vst1q_f32(output_data + i, result1);
8110     const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
8111     vst1q_f32(output_data + i + 4, result2);
8112     const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
8113     vst1q_f32(output_data + i + 8, result3);
8114     const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
8115     vst1q_f32(output_data + i + 12, result4);
8116   }
8117 
8118   for (; i <= size - 4; i += 4) {
8119     const float32x4_t input = vld1q_f32(input_data + i);
8120     const float32x4_t temp = vmulq_f32(input, alpha_dup);
8121     const uint32x4_t mask = vcgeq_f32(input, zero_dup);
8122     const float32x4_t result = vbslq_f32(mask, input, temp);
8123     vst1q_f32(output_data + i, result);
8124   }
8125 #endif  // USE_NEON
8126   for (; i < size; ++i) {
8127     const float input = input_data[i];
8128     output_data[i] = input >= 0.f ? input : input * alpha;
8129   }
8130 }
8131 
PReluElementWise(int flat_size,const ArithmeticParams & params,const float * alpha_data,const float * input_data,float * output_data)8132 inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
8133                              const float* alpha_data, const float* input_data,
8134                              float* output_data) {
8135   ruy::profiler::ScopeLabel label("PreluElementWise/float");
8136 
8137   int i = 0;
8138 #ifdef USE_NEON
8139   const float32x4_t zero_dup = vdupq_n_f32(0.0f);
8140   for (; i <= flat_size - 16; i += 16) {
8141     const float32x4_t input1 = vld1q_f32(input_data + i);
8142     const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
8143     const float32x4_t input2 = vld1q_f32(input_data + i + 4);
8144     const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
8145     const float32x4_t input3 = vld1q_f32(input_data + i + 8);
8146     const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
8147     const float32x4_t input4 = vld1q_f32(input_data + i + 12);
8148     const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
8149 
8150     const float32x4_t temp1 = vmulq_f32(input1, alpha1);
8151     const float32x4_t temp2 = vmulq_f32(input2, alpha2);
8152     const float32x4_t temp3 = vmulq_f32(input3, alpha3);
8153     const float32x4_t temp4 = vmulq_f32(input4, alpha4);
8154 
8155     const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
8156     const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
8157     const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
8158     const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
8159 
8160     const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
8161     vst1q_f32(output_data + i, result1);
8162     const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
8163     vst1q_f32(output_data + i + 4, result2);
8164     const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
8165     vst1q_f32(output_data + i + 8, result3);
8166     const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
8167     vst1q_f32(output_data + i + 12, result4);
8168   }
8169 
8170   for (; i <= flat_size - 4; i += 4) {
8171     const float32x4_t input = vld1q_f32(input_data + i);
8172     const float32x4_t alpha = vld1q_f32(alpha_data + i);
8173 
8174     const float32x4_t temp = vmulq_f32(input, alpha);
8175     const uint32x4_t mask = vcgeq_f32(input, zero_dup);
8176     const float32x4_t result = vbslq_f32(mask, input, temp);
8177     vst1q_f32(output_data + i, result);
8178   }
8179 #endif  // USE_NEON
8180   for (; i < flat_size; ++i) {
8181     const float input = input_data[i];
8182     const float alpha = alpha_data[i];
8183     output_data[i] = input >= 0.f ? input : input * alpha;
8184   }
8185 }
8186 
BroadcastPReluDispatch(const ArithmeticParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & alpha_shape,const float * alpha_data,const RuntimeShape & output_shape,float * output_data,float (* func)(float,float))8187 inline void BroadcastPReluDispatch(
8188     const ArithmeticParams& params, const RuntimeShape& input_shape,
8189     const float* input_data, const RuntimeShape& alpha_shape,
8190     const float* alpha_data, const RuntimeShape& output_shape,
8191     float* output_data, float (*func)(float, float)) {
8192   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8193     return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
8194         input_shape, input_data, alpha_shape, alpha_data, output_shape,
8195         output_data, func);
8196   }
8197 
8198   BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
8199                           alpha_data, output_shape, output_data,
8200                           PReluElementWise, PReluScalarBroadcast);
8201 }
8202 
8203 }  // namespace optimized_ops
8204 }  // namespace tflite
8205 
8206 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8207 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8208 #pragma GCC diagnostic pop
8209 #endif
8210 
8211 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
8212