/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <type_traits>
#include <vector>

#include "Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/add_n.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/cast.h"
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/kernels/internal/reference/elu.h"
#include "tensorflow/lite/kernels/internal/reference/exp.h"
#include "tensorflow/lite/kernels/internal/reference/fill.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/kernels/internal/reference/floor_div.h"
#include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/gather.h"
#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/neg.h"
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

namespace reference_ops {

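// Elementwise rectified-linear activation: clamps each input value to the
// range [0, +inf).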
template <typename T>
inline void Relu(const RuntimeShape& input_shape, const T* input_data,
                 const RuntimeShape& output_shape, T* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T lower = 0;
    const T clamped = val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

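// Elementwise clamp of each input value to the range [-1, 1].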
template <typename T>
inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = -1;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

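// Elementwise clamp of each input value to the range [0, 6] (float only).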
inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("Relu6 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const float val = input_data[i];
    const float upper = 6;
    const float lower = 0;
    const float clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

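// Quantized ReluX: requantizes each input value using the input/output
// offsets, multiplier and shift in `params`, then clamps the result to
// [params.quantized_activation_min, params.quantized_activation_max].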
template <typename T>
inline void ReluX(const tflite::ReluParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const int32 val = static_cast<int32_t>(input_data[i]);
    int32 clamped = params.output_offset +
                    MultiplyByQuantizedMultiplier(val - params.input_offset,
                                                  params.output_multiplier,
                                                  params.output_shift);
    clamped = std::max(params.quantized_activation_min, clamped);
    clamped = std::min(params.quantized_activation_max, clamped);
    output_data[i] = static_cast<T>(clamped);
  }
}

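// Quantized ReluX variant that only clamps to the quantized activation range,
// without requantizing the input.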
template <typename T>
inline void ReluX(const tflite::ActivationParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  const T max_value = params.quantized_activation_max;
  const T min_value = params.quantized_activation_min;
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T clamped = val > max_value   ? max_value
                      : val < min_value ? min_value
                                        : val;
    output_data[i] = clamped;
  }
}

// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
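// Broadcasting uint8 multiplication over the "fivefold" loop structure
// produced by ProcessBroadcastShapes. When the broadcast category is not
// kFirstInputBroadcastsFast, the two inputs (and their offsets) are swapped
// so that the loops below can assume the first-input-broadcasts-fast layout.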
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  ArithmeticParams switched_params = unswitched_params;
  switched_params.input1_offset = unswitched_params.input2_offset;
  switched_params.input2_offset = unswitched_params.input1_offset;

  const bool use_unswitched =
      unswitched_params.broadcast_category ==
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;

  const ArithmeticParams& params =
      use_unswitched ? unswitched_params : switched_params;
  const uint8* input1_data =
      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
  const uint8* input2_data =
      use_unswitched ? unswitched_input2_data : unswitched_input1_data;

  // Fivefold nested loops. The second input resets its position for each
  // iteration of the second loop. The first input resets its position at the
  // beginning of the fourth loop. The innermost loop is an elementwise Mul of
  // sections of the arrays.
  uint8* output_data_ptr = output_data;
  const uint8* input1_data_ptr = input1_data;
  const uint8* input2_data_reset = input2_data;
  int y0 = params.broadcast_shape[0];
  int y1 = params.broadcast_shape[1];
  int y2 = params.broadcast_shape[2];
  int y3 = params.broadcast_shape[3];
  int y4 = params.broadcast_shape[4];
  for (int i0 = 0; i0 < y0; ++i0) {
    const uint8* input2_data_ptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
                         output_data_ptr);
          input2_data_ptr += y4;
          output_data_ptr += y4;
        }
        input1_data_ptr += y4;
      }
    }
    input2_data_reset = input2_data_ptr;
  }
}

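// Elementwise multiplication of int16 values interpreted as Q0.15 fixed-point
// (gemmlowp FixedPoint with 0 integer bits); the raw product is written out
// without further rescaling or clamping.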
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    output_data[i] = unclamped_result.raw();
  }
}

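// Elementwise multiplication of Q0.15 int16 inputs producing a uint8 output:
// the fixed-point product is rescaled by a rounding right shift of 8, clamped
// to the quantized activation range (expressed relative to the output offset),
// and then shifted by params.output_offset.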
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
  int32 output_offset = params.output_offset;
  int32 output_activation_min = params.quantized_activation_min;
  int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result =
        std::min<int16>(output_activation_max - output_offset, rescaled_result);
    clamped_result =
        std::max<int16>(output_activation_min - output_offset, clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

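// Elementwise int16 fixed-point subtraction (input1 - input2). Exactly one of
// the two inputs may carry a non-zero (negative) shift; that input is scaled
// down by the corresponding right shift before the saturating subtraction,
// and the result is clamped to the quantized activation range.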
inline void Sub16(const ArithmeticParams& params,
                  const RuntimeShape& input1_shape, const int16_t* input1_data,
                  const RuntimeShape& input2_shape, const int16_t* input2_data,
                  const RuntimeShape& output_shape, int16_t* output_data) {
  ruy::profiler::ScopeLabel label("Sub/Int16");
  const int input1_shift = params.input1_shift;
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  const int16 output_activation_min = params.quantized_activation_min;
  const int16 output_activation_max = params.quantized_activation_max;

  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
  TFLITE_DCHECK_LE(input1_shift, 0);
  TFLITE_DCHECK_LE(params.input2_shift, 0);
  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
  const int input_right_shift =
      input1_shift == 0 ? -params.input2_shift : -input1_shift;

  if (input1_shift == 0) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(input_ready_scaled, scaled_input);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  } else {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(scaled_input, input_ready_scaled);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  }
}

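// Packs (stacks) `params.inputs_count` tensors of identical shape along
// `params.axis` of the output. For example, packing four [3, 2] inputs along
// axis 0 yields a [4, 3, 2] output whose i-th slice along axis 0 is input i.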
template <typename Scalar>
void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
          const Scalar* const* input_data, const RuntimeShape& output_shape,
          Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Pack");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  int inputs_count = params.inputs_count;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = params.axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  for (int i = 0; i < inputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      const Scalar* input_ptr = input_data[i] + copy_size * k;
      int loc = k * inputs_count * copy_size + i * copy_size;
      memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
    }
  }
}

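// Inverse of Pack: splits the input along `params.axis` into
// `params.num_split` outputs, each dropping that axis (negative axis values
// count from the end).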
template <typename Scalar>
void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape& output_shape,
            Scalar* const* output_datas) {
  ruy::profiler::ScopeLabel label("Unpack");
  const int dimensions = input_shape.DimensionsCount();
  const int outputs_count = params.num_split;

  int outer_size = 1;
  int axis = params.axis;
  if (axis < 0) {
    axis += dimensions;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, dimensions);
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; ++i) {
    copy_size *= input_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);

  for (int i = 0; i < outputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      Scalar* output_ptr = output_datas[i] + copy_size * k;
      int loc = k * outputs_count * copy_size + i * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

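// Pack for quantized uint8 inputs carrying per-input quantization parameters.
// Inputs whose zero point and scale already match the output are copied
// directly; inputs that would need requantization hit the assert(false) path,
// i.e. that case is not expected here.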
template <typename Scalar>
void PackWithScaling(const PackParams& params,
                     const RuntimeShape* const* input_shapes,
                     const uint8* const* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("PackWithScaling");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  const int32* input_zeropoint = params.input_zeropoint;
  const float* input_scale = params.input_scale;
  int inputs_count = params.inputs_count;
  const int32 output_zeropoint = params.output_zeropoint;
  const float output_scale = params.output_scale;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  Scalar* output_ptr = output_data;
  const float inverse_output_scale = 1.f / output_scale;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < inputs_count; ++i) {
      if (input_zeropoint[i] == output_zeropoint &&
          input_scale[i] == output_scale) {
        memcpy(output_ptr, input_data[i] + k * copy_size,
               copy_size * sizeof(Scalar));
      } else {
        assert(false);
        const float scale = input_scale[i] * inverse_output_scale;
        const float bias = -input_zeropoint[i] * scale;
        auto input_ptr = input_data[i];
        for (int j = 0; j < copy_size; ++j) {
          const int32_t value =
              static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
              output_zeropoint;
          output_ptr[j] =
              static_cast<uint8_t>(std::max(std::min(255, value), 0));
        }
      }
      output_ptr += copy_size;
    }
  }
}

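// Concatenates the inputs along the depth (last) axis by forwarding to
// Concatenation with the axis fixed to 3.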
template <typename Scalar>
void DepthConcatenation(const ConcatenationParams& params,
                        const RuntimeShape* const* input_shapes,
                        const Scalar* const* input_data,
                        const RuntimeShape& output_shape, Scalar* output_data) {
  ruy::profiler::ScopeLabel label("DepthConcatenation");
  auto params_copy = params;
  params_copy.axis = 3;
  Concatenation(params_copy, input_shapes, input_data, output_shape,
                output_data);
}

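// Float reference LSTM cell: depth-concatenates the input with the previous
// activations, runs a single fully-connected layer producing the four gate
// pre-activations, then applies the usual gate math (sigmoid/tanh) to update
// the cell state and emit the new activations.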
inline void LstmCell(
    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
    const float* prev_activ_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& unextended_bias_shape,
    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
    const float* prev_state_data,
    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  const int weights_dim_count = weights_shape.DimensionsCount();
  const int batches =
      MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
                  output_state_shape, 0, output_activ_shape, 0);
  const int height =
      MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
                  output_state_shape, 1, output_activ_shape, 1);
  const int width =
      MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
                  output_state_shape, 2, output_activ_shape, 2);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);

  // Concatenate prev_activ and input data together
  std::vector<float const*> concat_input_arrays_data;
  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
  concat_input_arrays_data.push_back(input_data);
  concat_input_arrays_data.push_back(prev_activ_data);
  concat_input_arrays_shapes.push_back(&input_shape);
  concat_input_arrays_shapes.push_back(&prev_activ_shape);
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = concat_input_arrays_data.size();
  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
                &(concat_input_arrays_data[0]), concat_temp_shape,
                concat_temp_data);

  // Fully connected
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
                 weights_data, bias_shape, bias_data, activ_temp_shape,
                 activ_temp_data);

  // Memory state update (the LSTM "guts")
  for (int b = 0; b < batches; ++b) {
    for (int w = 0; w < width; ++w) {
      for (int h = 0; h < height; ++h) {
        for (int c = 0; c < output_depth; ++c) {
          const float input_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      0 * output_depth + c)]));
          const float new_input = std::tanh(activ_temp_data[Offset(
              activ_temp_shape, b, h, w, 1 * output_depth + c)]);
          const float forget_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      2 * output_depth + c)]));
          const float output_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      3 * output_depth + c)]));
          const float new_state =
              input_gate * new_input +
              forget_gate *
                  prev_state_data[Offset(prev_state_shape, b, h, w, c)];
          output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
          output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
              output_gate * std::tanh(new_state);
        }
      }
    }
  }
}

// Quantized LSTM cell implementation.
// The quantization of the input, output arrays is as follows:
//  - The input activations are quantized as uint8 on the interval
//    [-1, 127/128].
//    The rationale for that is that it is the natural interval for output
//    activations (see next point) and these need to be concatenated together.
//    We could accommodate different ranges by re-scaling, but we empirically
//    found that setting the input activations range to be [-1, 127/128] in the
//    first place, removing the need for re-scaling, greatly improves accuracy.
//  - The output activations are quantized as uint8 on the interval
//    [-1, 127/128].
//    The rationale for that is that the definition of an LSTM cell makes them
//    intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
//    makes for simpler, more accurate fixed-point arithmetic.
//  - The output-at-previous-timestep state array is quantized in the same way
//    as the output activations.
//  - The internal LSTM memory (not the output-at-previous-timestep, the other
//    internal state array) is int16-quantized and may use any power-of-two,
//    symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
//    StateIntegerBits below, see the below discussion of that template
//    parameter ("The StateIntegerBits template parameter").
//  - The output of the internal fully-connected node is int16-quantized
//    on the interval [-8, 8 * 32767/32768], the rationale for which is
//    explained just below ("Why [-8, 8] for fully-connected output?").
//
//
// === The StateIntegerBits template parameter ===
//
// The StateIntegerBits template parameter controls the fixed-point format used
// to represent the internal memory of the LSTM cell (not the
// output-at-previous-timestep, the other internal state array). It's currently
// a template parameter so that the model can control that. The most typical
// value for StateIntegerBits is 4. Other plausible values are anywhere between
// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
// and drop that template parameter. The reason why it can't be a runtime
// parameter is that this controls the fixed-point format used, i.e. we need to
// generate actually different code based on it. In particular, we generate code
// for a fixed-point tanh() implementation for that format, which internally
// uses a fixed-point exp() implementation, which internally uses a
// barrel-shifter with a number of steps that depends on StateIntegerBits.
// Another consequence of that is that a higher value of StateIntegerBits
// results in a more expensive implementation (more barrel shifter steps
// needed).
//
//
// === Why [-8, 8] for fully-connected output? ===
//
// This array is only fed to Logistic and Tanh functions, for which
// the quantized implementation will want to use fixed-point arithmetic,
// requiring a power-of-two representation interval. Thus, we should right
// away quantize this array to a power-of-two interval; otherwise,
// implementation will need to rescale that, losing any benefit that a tighter
// representation interval might otherwise yield, while introducing some
// numerical error and computational overhead.
//
// Now, Logistic and Tanh
// are nearly constant (nearly equal to their horizontal asymptotes)
// outside of a small bounded interval around 0:
//
//   Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
//   Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
//   Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
//
// From this, we see that clamping to [-4, 4] would be too inaccurate
// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
// while clamping to [-16, 16] would make no difference even in float32.
// However, for a fixed-point implementation in 16-bit integers, using 5
// integer bits to represent the [-16, 16] range would leave only 11
// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
// representable values. Notice that this is higher than the
// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
// Using [-8, 8] thus seems like the better compromise overall, enjoying
// an increment of 2.4e-4 between representable values and a worst-case
// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
// [-16, 16].
//
// Moreover, all other things being equal, it is nice to choose the narrower
// representation range, as that makes the implementation of fixed-point
// math functions a little cheaper (each integer bit requires an additional
// barrel-shifter step in the implementation of exp(-x)). That is further
// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
// sense for 32-bit float or 32-bit fixed-point quantization, but we are
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
inline void LstmCell(const LstmCellParams& params,
                     const RuntimeShape& unextended_input_shape,
                     const uint8* input_data_uint8,
                     const RuntimeShape& unextended_prev_activ_shape,
                     const uint8* prev_activ_data_uint8,
                     const RuntimeShape& weights_shape,
                     const uint8* weights_data_uint8,
                     const RuntimeShape& unextended_bias_shape,
                     const int32* bias_data_int32,
                     const RuntimeShape& unextended_prev_state_shape,
                     const int16* prev_state_data_int16,
                     const RuntimeShape& unextended_output_state_shape,
                     int16* output_state_data_int16,
                     const RuntimeShape& unextended_output_activ_shape,
                     uint8* output_activ_data_uint8,
                     const RuntimeShape& unextended_concat_temp_shape,
                     uint8* concat_temp_data_uint8,
                     const RuntimeShape& unextended_activ_temp_shape,
                     int16* activ_temp_data_int16, void* gemmlowp_context) {
  (void)gemmlowp_context;  // only used in optimized code.
  int32 weights_zero_point = params.weights_zero_point;
  int32 accum_multiplier = params.accum_multiplier;
  int accum_shift = params.accum_shift;
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  // Gather dimensions information, and perform consistency checks.
  const int weights_dim_count = weights_shape.DimensionsCount();
  const int outer_size = MatchingFlatSizeSkipDim(
      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
      output_activ_shape);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
  const int fc_output_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
  const int fc_accum_depth = total_input_depth;
  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                              prev_activ_data_uint8};
  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                       &prev_activ_shape};
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = 2;
  Concatenation(concat_params, concat_input_arrays_shapes,
                concat_input_arrays_data, concat_temp_shape,
                concat_temp_data_uint8);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32 accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16 weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
      accum =
          MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16 rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16 clamped_output_activ =
          std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}

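// Splits the input along `axis` into `params.num_split` outputs; the outputs'
// sizes along that axis must sum to the input's size along it, and all other
// dimensions must match.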
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
           const Scalar* input_data, const RuntimeShape* const* output_shapes,
           Scalar* const* output_data) {
  ruy::profiler::ScopeLabel label("Split");
  const int split_dimensions = input_shape.DimensionsCount();
  int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
  int outputs_count = params.num_split;
  TFLITE_DCHECK_LT(axis, split_dimensions);

  int64_t split_size = 0;
  for (int i = 0; i < outputs_count; i++) {
    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
    for (int j = 0; j < split_dimensions; j++) {
      if (j != axis) {
        MatchingDim(*output_shapes[i], j, input_shape, j);
      }
    }
    split_size += output_shapes[i]->Dims(axis);
  }
  TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  // For all output arrays,
  // FlatSize() = outer_size * Dims(axis) * base_inner_size;
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_shape.Dims(i);
  }

  const Scalar* input_ptr = input_data;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < outputs_count; ++i) {
      const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
      memcpy(output_data[i] + k * copy_size, input_ptr,
             copy_size * sizeof(Scalar));
      input_ptr += copy_size;
    }
  }
}

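// Flattened offset of element (b, h, w) in a [batch, height, width] layout.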
inline int NodeOffset(int b, int h, int w, int height, int width) {
  return (b * height + h) * width + w;
}

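// Local response normalization across the depth dimension:
// output[c] = input[c] * (bias + alpha * sum of squared inputs in a window of
// radius op_params.range around channel c)^(-beta).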
inline void LocalResponseNormalization(
    const tflite::LocalResponseNormalizationParams& op_params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    for (int c = 0; c < depth; ++c) {
      const int begin_input_c = std::max(0, c - op_params.range);
      const int end_input_c = std::min(depth, c + op_params.range);
      float accum = 0.f;
      for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
        const float input_val = input_data[i * depth + input_c];
        accum += input_val * input_val;
      }
      const float multiplier =
          std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
      output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
    }
  }
}

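// Float LogSoftmax over the last dimension, computed as
// x - max(x) - log(sum(exp(x - max(x)))) for numerical stability.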
inline void LogSoftmax(const SoftmaxParams& params,
                       const RuntimeShape& input_shape, const float* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Find max element value which we'll use to ensure numerical stability
    // taking advantage of the following equality:
    // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
    float max = std::numeric_limits<float>::lowest();
    for (int c = 0; c < depth; ++c) {
      max = std::max(max, input_data[i * depth + c]);
    }

    // Compute sum.
    float sum = 0.f;
    for (int c = 0; c < depth; ++c) {
      sum += std::exp(input_data[i * depth + c] - max);
    }

    // Compute result.
    const float log_sum = std::log(sum);
    for (int c = 0; c < depth; ++c) {
      output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
    }
  }
}

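// Quantized (uint8) LogSoftmax over the last dimension, using gemmlowp
// fixed-point arithmetic for the exp and log computations.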
inline void LogSoftmax(const SoftmaxParams& params,
                       const RuntimeShape& input_shape, const uint8* input_data,
                       const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("LogSoftmax/8bit");
  const int32 input_multiplier = params.input_multiplier;
  const int32 input_left_shift = params.input_left_shift;
  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large
  // as -32 before multiplying by input_beta_multiplier, and therefore as
  // large as -16 afterwards.  Note that exp(-8) is definitely not
  // insignificant to accumulation, but exp(-16) definitely is.
  static constexpr int kScaledDiffIntegerBits = 5;
  static constexpr int kAccumulationIntegerBits = 12;
  static constexpr int kOutputIntegerBits = 4;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    uint8 max_in_row = 0;
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c) {
      int32 input_diff =
          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32 input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_multiplier, input_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

    const int32 fixed_log_sum_of_exps =
        log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
            sum_of_exps)
            .raw();

    // rescaled_diff_min is smallest representable in
    // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
    // log-sub-exps that will be subtracted in the loop.
    //
    // The thresholds diff_min, etc are negative.
    const int rescaled_diff_min =
        fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
    const int adjusted_diff_min =
        std::max(diff_min - 1,  // Note use of > below instead of >= above.
                 MultiplyByQuantizedMultiplierSmallerThanOneExp(
                     rescaled_diff_min, reverse_scaling_divisor,
                     -reverse_scaling_right_shift));

    for (int c = 0; c < depth; ++c) {
      int32 input_diff =
          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
      if (input_diff > adjusted_diff_min) {
        const int32 input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_multiplier, input_left_shift);
        int32 unsat_output =
            gemmlowp::RoundingDivideByPOT(
                (input_diff_rescaled - fixed_log_sum_of_exps),
                31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
            255;

        output_data[i * depth + c] = static_cast<uint8>(
            std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
      } else {
        // Set output to smallest value.
        output_data[i * depth + c] = 0;
      }
    }
  }
}

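// Converts half-precision (Eigen::half) input values to float.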
inline void Dequantize(const RuntimeShape& input_shape,
                       const Eigen::half* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = Eigen::half_impl::half_to_float(input_data[i]);
  }
}

1035 inline void FakeQuant(const tflite::FakeQuantParams& op_params,
1036                       const RuntimeShape& input_shape, const float* input_data,
1037                       const RuntimeShape& output_shape, float* output_data) {
1038   ruy::profiler::ScopeLabel label("FakeQuant");
1039   float rmin = op_params.minmax.min;
1040   float rmax = op_params.minmax.max;
1041   int num_bits = op_params.num_bits;
1042   // 0 should always be a representable value. Let's assume that the initial
1043   // [min, max] range contains 0.
1044   TFLITE_DCHECK_LE(rmin, 0.0f);
1045   TFLITE_DCHECK_GE(rmax, 0.0f);
1046   TFLITE_DCHECK_LT(rmin, rmax);
1047 
1048   // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
1049   int quant_min = 0;
1050   int quant_max = (1 << num_bits) - 1;
1051   float nudged_min, nudged_max, nudged_scale;
1052   NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
1053                          &nudged_max, &nudged_scale);
1054   const int flat_size = MatchingFlatSize(input_shape, output_shape);
1055   FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
1056                     output_data, flat_size);
1057 }
1058 
1059 // Result of the common subroutine used by `GatherNd` and `GatherNdString`.
1060 struct GatherNdHelperResult {
1061   int n_slices;
1062   int slice_size;
1063   int indices_nd;
1064   std::vector<int> dims_to_count;
1065 };
1066 
1067 // Returns the common values used by both `GatherNd` and `GatherNdString`.
1068 inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
1069                                            const RuntimeShape& indices_shape) {
1070   GatherNdHelperResult ret;
1071   ret.n_slices = 1;
1072   ret.slice_size = 1;
1073   const int indices_dims = indices_shape.DimensionsCount();
1074   ret.indices_nd = indices_shape.Dims(indices_dims - 1);
1075   const int params_dims = params_shape.DimensionsCount();
1076   for (int i = 0; i < indices_dims - 1; ++i) {
1077     ret.n_slices *= indices_shape.Dims(i);
1078   }
1079   for (int i = ret.indices_nd; i < params_dims; ++i) {
1080     ret.slice_size *= params_shape.Dims(i);
1081   }
1082 
1083   int remain_flat_size = params_shape.FlatSize();
1084   ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
1085   for (int i = 0; i < ret.indices_nd; ++i) {
1086     ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
1087     remain_flat_size = ret.dims_to_count[i];
1088   }
1089 
1090   return ret;
1091 }
1092 
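// Gathers slices from params: for each of the n_slices index tuples, copies
// the slice_size contiguous elements addressed by that tuple into the output.
// Example (hypothetical shapes): for params of shape [4, 3, 2] and indices of
// shape [5, 2], n_slices = 5, slice_size = 2 and dims_to_count = {6, 2}, so an
// index tuple (i, j) reads from flat offset 6 * i + 2 * j.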
1093 template <typename ParamsT, typename IndicesT = int32>
1094 inline void GatherNd(const RuntimeShape& params_shape,
1095                      const ParamsT* params_data,
1096                      const RuntimeShape& indices_shape,
1097                      const IndicesT* indices_data,
1098                      const RuntimeShape& output_shape, ParamsT* output_data) {
1099   ruy::profiler::ScopeLabel label("GatherNd");
1100 
1101   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1102   for (int i = 0; i < res.n_slices; ++i) {
1103     int from_pos = 0;
1104     for (int j = 0; j < res.indices_nd; ++j) {
1105       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1106     }
1107     std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
1108                 sizeof(ParamsT) * res.slice_size);
1109   }
1110 }
1111 
1112 #ifndef TF_LITE_STATIC_MEMORY
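// String variant of GatherNd: gathers the selected strings through a
// DynamicBuffer and writes them into the output tensor.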
1113 template <typename IndicesT = int32>
1114 inline void GatherNdString(const RuntimeShape& params_shape,
1115                            const TfLiteTensor* params_data,
1116                            const RuntimeShape& indices_shape,
1117                            const IndicesT* indices_data,
1118                            const RuntimeShape& output_shape,
1119                            TfLiteTensor* output_data) {
1120   ruy::profiler::ScopeLabel label("GatherNdString");
1121 
1122   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1123   DynamicBuffer buffer;
1124   for (int i = 0; i < res.n_slices; ++i) {
1125     int from_pos = 0;
1126     for (int j = 0; j < res.indices_nd; ++j) {
1127       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1128     }
1129     for (int j = 0; j < res.slice_size; ++j) {
1130       buffer.AddString(GetString(params_data, from_pos + j));
1131     }
1132   }
1133   buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
1134 }
1135 #endif
1136 
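// Scatter-add: the output is zero-initialized, then each updates slice is
// added at the position given by the corresponding index tuple, so duplicate
// indices accumulate.
// Example (hypothetical values): with indices [[0], [2], [0]], updates of
// shape [3, 4] and output of shape [4, 4], rows 0 and 2 of the output are
// written, and the two updates that target row 0 are summed.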
1137 template <typename IndicesT, typename UpdatesT>
1138 inline void ScatterNd(const RuntimeShape& indices_shape,
1139                       const IndicesT* indices_data,
1140                       const RuntimeShape& updates_shape,
1141                       const UpdatesT* updates_data,
1142                       const RuntimeShape& output_shape, UpdatesT* output_data) {
1143   ruy::profiler::ScopeLabel label("ScatterNd");
1144 
1145   int n_slices = 1;
1146   int slice_size = 1;
1147   const int outer_dims = indices_shape.DimensionsCount() - 1;
1148   const int indices_nd = indices_shape.Dims(outer_dims);
1149   const int updates_dims = updates_shape.DimensionsCount();
1150   for (int i = 0; i < outer_dims; ++i) {
1151     n_slices *= indices_shape.Dims(i);
1152   }
1153   for (int i = outer_dims; i < updates_dims; ++i) {
1154     slice_size *= updates_shape.Dims(i);
1155   }
1156 
1157   int output_flat_size = output_shape.FlatSize();
1158   int remain_flat_size = output_flat_size;
1159   std::vector<int> dims_to_count(indices_nd, 0);
1160   for (int i = 0; i < indices_nd; ++i) {
1161     dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
1162     remain_flat_size = dims_to_count[i];
1163   }
1164 
1165   memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
1166   for (int i = 0; i < n_slices; ++i) {
1167     int to_pos = 0;
1168     for (int j = 0; j < indices_nd; ++j) {
1169       IndicesT idx = indices_data[i * indices_nd + j];
1170       TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
1171       to_pos += idx * dims_to_count[j];
1172     }
1173     for (int j = 0; j < slice_size; j++) {
1174       output_data[to_pos + j] += updates_data[i * slice_size + j];
1175     }
1176   }
1177 }
1178 
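// Maps an output coordinate to a source position in the input (optionally
// using half-pixel centers) and returns the two input indices that bracket it.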
1179 inline void ComputeInterpolationValues(const float value, const float scale,
1180                                        const bool half_pixel_centers,
1181                                        int32 input_size, float* scaled_value,
1182                                        int32* lower_bound, int32* upper_bound) {
1183   if (half_pixel_centers) {
1184     *scaled_value = (value + 0.5f) * scale - 0.5f;
1185   } else {
1186     *scaled_value = value * scale;
1187   }
1188   float scaled_value_floor = std::floor(*scaled_value);
1189   *lower_bound =
1190       std::max(static_cast<int32>(scaled_value_floor), static_cast<int32>(0));
1191   *upper_bound =
1192       std::min(static_cast<int32>(std::ceil(*scaled_value)), input_size - 1);
1193 }
1194 
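// Bilinear resize: each output pixel is a weighted average of the four input
// pixels surrounding its source position.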
1195 template <typename T>
1196 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
1197                            const RuntimeShape& unextended_input_shape,
1198                            const T* input_data,
1199                            const RuntimeShape& unextended_output_size_shape,
1200                            const int32* output_size_data,
1201                            const RuntimeShape& unextended_output_shape,
1202                            T* output_data) {
1203   // If half_pixel_centers is True, align_corners must be False.
1204   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
1205   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1206   TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
1207   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1208   const RuntimeShape input_shape =
1209       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1210   const RuntimeShape output_size_shape =
1211       RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
1212   const RuntimeShape output_shape =
1213       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1214 
1215   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1216   int32 input_height = input_shape.Dims(1);
1217   int32 input_width = input_shape.Dims(2);
1218   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1219 
1220   TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
1221   TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
1222   TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
1223   TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
1224   int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
1225   int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
1226 
1227   float height_scale = static_cast<float>(input_height) / output_height;
1228   float width_scale = static_cast<float>(input_width) / output_width;
1229   if (op_params.align_corners && output_height > 1) {
1230     height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
1231   }
1232   if (op_params.align_corners && output_width > 1) {
1233     width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
1234   }
1235 
1236   for (int b = 0; b < batches; ++b) {
1237     for (int y = 0; y < output_height; ++y) {
1238       float input_y;
1239       int32 y0, y1;
1240       ComputeInterpolationValues(y, height_scale, op_params.half_pixel_centers,
1241                                  input_height, &input_y, &y0, &y1);
1242       for (int x = 0; x < output_width; ++x) {
1243         float input_x;
1244         int32 x0, x1;
1245         ComputeInterpolationValues(x, width_scale, op_params.half_pixel_centers,
1246                                    input_width, &input_x, &x0, &x1);
1247         for (int c = 0; c < depth; ++c) {
1248           T interpolation =
1249               static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
1250                                  (1 - (input_y - y0)) * (1 - (input_x - x0)) +
1251                              input_data[Offset(input_shape, b, y1, x0, c)] *
1252                                  (input_y - y0) * (1 - (input_x - x0)) +
1253                              input_data[Offset(input_shape, b, y0, x1, c)] *
1254                                  (1 - (input_y - y0)) * (input_x - x0) +
1255                              input_data[Offset(input_shape, b, y1, x1, c)] *
1256                                  (input_y - y0) * (input_x - x0));
1257           output_data[Offset(output_shape, b, y, x, c)] = interpolation;
1258         }
1259       }
1260     }
1261   }
1262 }
1263 
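// Fixed-point counterpart of the float version above: positions carry 10
// fractional bits (scale_10 is the scale premultiplied by 1 << 10).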
1264 inline void ComputeInterpolationValues(const int32 value, const int32 scale_10,
1265                                        const bool half_pixel_centers,
1266                                        int32 input_size, int32* scaled_value,
1267                                        int32* lower_bound, int32* upper_bound) {
1268   if (half_pixel_centers) {
1269     *scaled_value = value * scale_10 + scale_10 / 2 - (1 << 9);
1270   } else {
1271     *scaled_value = value * scale_10;
1272   }
1273   *lower_bound = std::max(*scaled_value / (1 << 10), 0);
1274   *upper_bound =
1275       std::min((*scaled_value + (1 << 10) - 1) / (1 << 10), input_size - 1);
1276 }
1277 
1278 // Same as above, but does not use any floating-point arithmetic for the resize.
1279 template <typename T>
1280 inline void ResizeBilinearInteger(
1281     const tflite::ResizeBilinearParams& op_params,
1282     const RuntimeShape& unextended_input_shape, const T* input_data,
1283     const RuntimeShape& unextended_output_size_shape,
1284     const int32* output_size_data, const RuntimeShape& unextended_output_shape,
1285     T* output_data) {
1286   // If half_pixel_centers is True, align_corners must be False.
1287   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
1288   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1289   TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
1290   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1291   const RuntimeShape input_shape =
1292       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1293   const RuntimeShape output_size_shape =
1294       RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
1295   const RuntimeShape output_shape =
1296       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1297 
1298   const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1299   const int32 input_height = input_shape.Dims(1);
1300   const int32 input_width = input_shape.Dims(2);
1301   const int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1302 
1303   TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
1304   TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
1305   TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
1306   TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
1307   const int32 output_height =
1308       output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
1309   const int32 output_width =
1310       output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
1311 
1312   int32 height_scale_10 =
1313       ((1 << 10) * input_height + output_height / 2) / output_height;
1314   int32 width_scale_10 =
1315       ((1 << 10) * input_width + output_width / 2) / output_width;
1316   if (op_params.align_corners && output_height > 1) {
1317     height_scale_10 =
1318         ((1 << 10) * (input_height - 1) + (output_height - 1) / 2) /
1319         (output_height - 1);
1320   }
1321   if (op_params.align_corners && output_width > 1) {
1322     width_scale_10 = ((1 << 10) * (input_width - 1) + (output_width - 1) / 2) /
1323                      (output_width - 1);
1324   }
1325 
1326   for (int b = 0; b < batches; ++b) {
1327     for (int y = 0; y < output_height; ++y) {
1328       int32 input_y, y0, y1;
1329       ComputeInterpolationValues(y, height_scale_10,
1330                                  op_params.half_pixel_centers, input_height,
1331                                  &input_y, &y0, &y1);
1332       for (int x = 0; x < output_width; ++x) {
1333         int32 input_x, x0, x1;
1334         ComputeInterpolationValues(x, width_scale_10,
1335                                    op_params.half_pixel_centers, input_width,
1336                                    &input_x, &x0, &x1);
1337         for (int c = 0; c < depth; ++c) {
1338           const int64_t output_20_ll =
1339               static_cast<int64_t>(
1340                   input_data[Offset(input_shape, b, y0, x0, c)]) *
1341               ((1 << 10) - (input_y - (1 << 10) * y0)) *
1342               ((1 << 10) - (input_x - (1 << 10) * x0));
1343           const int64_t output_20_lu =
1344               static_cast<int64_t>(
1345                   input_data[Offset(input_shape, b, y1, x0, c)]) *
1346               (input_y - (1 << 10) * y0) *
1347               ((1 << 10) - (input_x - (1 << 10) * x0));
1348           const int64_t output_20_rl =
1349               static_cast<int64_t>(
1350                   input_data[Offset(input_shape, b, y0, x1, c)]) *
1351               ((1 << 10) - (input_y - (1 << 10) * y0)) *
1352               (input_x - (1 << 10) * x0);
1353           const int64_t output_20_ru =
1354               static_cast<int64_t>(
1355                   input_data[Offset(input_shape, b, y1, x1, c)]) *
1356               (input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
1357           const int64_t output_20 =
1358               output_20_ll + output_20_lu + output_20_rl + output_20_ru;
1359           const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
1360           const T interpolation =
1361               static_cast<T>((output_20 + round) / (1 << 20));
1362           output_data[Offset(output_shape, b, y, x, c)] = interpolation;
1363         }
1364       }
1365     }
1366   }
1367 }
1368 
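// Copies the slice described by op_params from the (up to 5-D) input to the
// writer. The begin/size vectors are front-padded to rank 5, and a size of -1
// extends the slice to the end of that dimension.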
1369 template <typename T>
1370 inline void Slice(const tflite::SliceParams& op_params,
1371                   const RuntimeShape& input_shape,
1372                   const RuntimeShape& output_shape,
1373                   SequentialTensorWriter<T>* writer) {
1374   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
1375   TFLITE_DCHECK_LE(op_params.begin_count, 5);
1376   TFLITE_DCHECK_LE(op_params.size_count, 5);
1377   const int begin_count = op_params.begin_count;
1378   const int size_count = op_params.size_count;
1379   // We front-pad the begin and size vectors.
1380   std::array<int, 5> start;
1381   std::array<int, 5> stop;
1382   for (int i = 0; i < 5; ++i) {
1383     int padded_i = 5 - i;
1384     start[i] =
1385         begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
1386     stop[i] =
1387         (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
1388             ? ext_shape.Dims(i)
1389             : start[i] + op_params.size[size_count - padded_i];
1390   }
1391 
1392   for (int i0 = start[0]; i0 < stop[0]; ++i0) {
1393     for (int i1 = start[1]; i1 < stop[1]; ++i1) {
1394       for (int i2 = start[2]; i2 < stop[2]; ++i2) {
1395         for (int i3 = start[3]; i3 < stop[3]; ++i3) {
1396           for (int i4 = start[4]; i4 < stop[4]; ++i4) {
1397             writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
1398           }
1399         }
1400       }
1401     }
1402   }
1403 }
1404 
1405 template <typename T>
1406 inline void Slice(const tflite::SliceParams& op_params,
1407                   const RuntimeShape& input_shape, const T* input_data,
1408                   const RuntimeShape& output_shape, T* output_data) {
1409   SequentialTensorWriter<T> writer(input_data, output_data);
1410   return Slice(op_params, input_shape, output_shape, &writer);
1411 }
1412 
1413 template <typename T>
1414 inline void Slice(const tflite::SliceParams& op_params,
1415                   const RuntimeShape& input_shape, const TfLiteTensor* input,
1416                   const RuntimeShape& output_shape, TfLiteTensor* output) {
1417   SequentialTensorWriter<T> writer(input, output);
1418   return Slice(op_params, input_shape, output_shape, &writer);
1419 }
1420 
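// Element-wise minimum of input1 and the scalar value input2_data[0].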
1421 template <typename T>
1422 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
1423              const T* input2_data, const RuntimeShape& output_shape,
1424              T* output_data) {
1425   const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1426 
1427   auto min_value = input2_data[0];
1428   for (int i = 0; i < flat_size; i++) {
1429     output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
1430   }
1431 }
1432 
1433 // Convenience version that allows, for example, generated-code calls to be
1434 // the same as other binary ops.
1435 template <typename T>
1436 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
1437                     const RuntimeShape&, const T* input2_data,
1438                     const RuntimeShape& output_shape, T* output_data) {
1439   // Drop shape of second input: not needed.
1440   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
1441 }
1442 
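// Element-wise maximum of input1 and the scalar value input2_data[0].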
1443 template <typename T>
1444 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
1445              const T* input2_data, const RuntimeShape& output_shape,
1446              T* output_data) {
1447   const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1448 
1449   auto max_value = input2_data[0];
1450   for (int i = 0; i < flat_size; i++) {
1451     output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
1452   }
1453 }
1454 
1455 // Convenience version that allows, for example, generated-code calls to be
1456 // the same as other binary ops.
1457 template <typename T>
1458 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
1459                     const RuntimeShape&, const T* input2_data,
1460                     const RuntimeShape& output_shape, T* output_data) {
1461   // Drop shape of second input: not needed.
1462   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
1463 }
1464 
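// ArgMax is ArgMinMax with std::greater: it returns the index of the largest
// element along the reduced axis.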
1465 template <typename T1, typename T2, typename T3>
1466 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
1467             const T3* input2_data, const RuntimeShape& output_shape,
1468             T2* output_data) {
1469   ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
1470             std::greater<T1>());
1471 }
1472 
1473 // Convenience version that allows, for example, generated-code calls to be
1474 // the same as other binary ops.
1475 template <typename T1, typename T2, typename T3>
1476 inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
1477                    const RuntimeShape& input2_shape, const T3* input2_data,
1478                    const RuntimeShape& output_shape, T2* output_data) {
1479   // Drop shape of second input: not needed.
1480   ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
1481 }
1482 
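// Element-wise select: output[i] = condition[i] ? x[i] : y[i]; all shapes must
// have matching flat sizes.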
1483 template <typename D, typename T>
1484 void Select(const RuntimeShape& input_condition_shape,
1485             const D* input_condition_data, const RuntimeShape& input_x_shape,
1486             const T* input_x_data, const RuntimeShape& input_y_shape,
1487             const T* input_y_data, const RuntimeShape& output_shape,
1488             T* output_data) {
1489   const int64_t flatsize = MatchingFlatSize(
1490       input_condition_shape, input_x_shape, input_y_shape, output_shape);
1491   for (int64_t i = 0; i < flatsize; ++i) {
1492     output_data[i] =
1493         input_condition_data[i] ? input_x_data[i] : input_y_data[i];
1494   }
1495 }
1496 
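// Select where the condition indexes the outermost dimension (or is a scalar):
// each outer slice is copied wholesale from either x or y.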
1497 template <typename D, typename T>
1498 void RankOneSelect(const RuntimeShape& input_condition_shape,
1499                    const D* input_condition_data,
1500                    const RuntimeShape& input_x_shape, const T* input_x_data,
1501                    const RuntimeShape& input_y_shape, const T* input_y_data,
1502                    const RuntimeShape& output_shape, T* output_data) {
1503   const int64_t outer_size = input_condition_shape.FlatSize();
1504   int64_t inner_size;
1505   if (input_condition_shape.DimensionsCount() == 0) {
1506     inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
1507   } else {
1508     TFLITE_DCHECK_EQ(
1509         MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
1510         outer_size);
1511     inner_size =
1512         MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
1513   }
1514 
1515   int64_t offset = 0;
1516   for (int64_t i = 0; i < outer_size; i++) {
1517     const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
1518     memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
1519     offset += inner_size;
1520   }
1521 }
1522 
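// Select with element-wise 4-D broadcasting of the condition, x, and y tensors
// against the output shape.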
1523 template <typename D, typename T>
1524 void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
1525                            const D* input_condition_data,
1526                            const RuntimeShape& input_x_shape,
1527                            const T* input_x_data,
1528                            const RuntimeShape& input_y_shape,
1529                            const T* input_y_data,
1530                            const RuntimeShape& output_shape, T* output_data) {
1531   TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
1532   TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
1533   TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
1534   TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
1535 
1536   const RuntimeShape extended_output_shape =
1537       RuntimeShape::ExtendedShape(4, output_shape);
1538 
1539   NdArrayDesc<4> desc_condition;
1540   NdArrayDesc<4> desc_x;
1541   NdArrayDesc<4> desc_y;
1542   NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
1543                                       input_y_shape, &desc_condition, &desc_x,
1544                                       &desc_y);
1545 
1546   // In TensorFlow, the dimensions are canonically named (batch_number, row,
1547   // col, channel), with extents (batches, height, width, depth), with the
1548   // trailing dimension changing most rapidly (channels has the smallest
1549   // stride, typically 1 element).
1550   //
1551   // In generated C code, we store arrays with the dimensions reversed. The
1552   // first dimension has smallest stride.
1553   //
1554   // We name our variables by their TensorFlow convention, but generate C code
1555   // nesting loops such that the innermost loop has the smallest stride for
1556   // the best cache behavior.
1557   for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
1558     for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
1559       for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
1560         for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
1561           const int condition_index =
1562               SubscriptToIndex(desc_condition, b, y, x, c);
1563           const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
1564           const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
1565           output_data[Offset(extended_output_shape, b, y, x, c)] =
1566               input_condition_data[condition_index] ? input_x_data[x_index]
1567                                                     : input_y_data[y_index];
1568         }
1569       }
1570     }
1571   }
1572 }
1573 
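// Writes the row-major coordinates of every true element of the condition
// tensor to output_data, one tuple of cond_rank integers per true element.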
1574 template <typename D, typename T>
1575 void SelectTrueCoords(const RuntimeShape& input_condition_shape,
1576                       const D* input_condition_data, T* output_data) {
1577   const size_t size = input_condition_shape.FlatSize();
1578   if (size == 0) {
1579     // The condition tensor is empty, so there is nothing to output.
1580     return;
1581   }
1582   const size_t cond_rank = input_condition_shape.DimensionsCount();
1583 
1584   std::vector<int> dims_to_count(cond_rank, 0);
1585   int cur_flat_size = size;
1586   for (int i = 0; i < cond_rank; ++i) {
1587     dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
1588     cur_flat_size = dims_to_count[i];
1589   }
1590 
1591   int output_index = 0;
1592   for (int i = 0; i < size; ++i) {
1593     if (input_condition_data[i]) {
1594       // Insert the coordinate of the current item (row major) into output.
1595       int flat_index = i;
1596       for (int j = 0; j < cond_rank; ++j) {
1597         int coord_j = flat_index / dims_to_count[j];
1598         output_data[output_index * cond_rank + j] = coord_j;
1599         flat_index %= dims_to_count[j];
1600       }
1601       output_index++;
1602     }
1603   }
1604 }
1605 
1606 // For ease of implementation, the indices are always a vector of size-4 vectors.
1607 template <typename T, typename TI>
1608 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1609                           const T* values, T default_value,
1610                           bool value_is_scalar,
1611                           const RuntimeShape& unextended_output_shape,
1612                           T* output_data) {
1613   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1614   const RuntimeShape output_shape =
1615       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1616   const int value_count = indices.size();
1617 
1618   // First, fill output_data with the default value.
1619   const int num_elements = output_shape.FlatSize();
1620   for (int i = 0; i < num_elements; ++i) {
1621     output_data[i] = default_value;
1622   }
1623 
1624   // Handle the scalar-value case separately to avoid checking the boolean
1625   // condition on every loop iteration.
1626   if (value_is_scalar) {
1627     for (int i = 0; i < value_count; ++i) {
1628       const std::vector<TI>& index = indices[i];
1629       TFLITE_DCHECK_EQ(index.size(), 4);
1630       const T value = *values;  // just use the first value.
1631       output_data[Offset(output_shape, index[0], index[1], index[2],
1632                          index[3])] = value;
1633     }
1634     return;
1635   }
1636 
1637   // Go through the values and indices to fill the sparse values.
1638   for (int i = 0; i < value_count; ++i) {
1639     const std::vector<TI>& index = indices[i];
1640     TFLITE_DCHECK_EQ(index.size(), 4);
1641     const T value = values[i];
1642     output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
1643         value;
1644   }
1645 }
1646 
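// Element-wise power: output[i] = pow(input1[i], input2[i]).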
1647 template <typename T>
1648 inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
1649                 const RuntimeShape& input2_shape, const T* input2_data,
1650                 const RuntimeShape& output_shape, T* output_data) {
1651   const int flat_size =
1652       MatchingFlatSize(input1_shape, input2_shape, output_shape);
1653   for (int i = 0; i < flat_size; ++i) {
1654     output_data[i] = std::pow(input1_data[i], input2_data[i]);
1655   }
1656 }
1657 
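// Pow with element-wise 4-D broadcasting between the two inputs.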
1658 template <typename T>
1659 inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
1660                                const T* input1_data,
1661                                const RuntimeShape& unextended_input2_shape,
1662                                const T* input2_data,
1663                                const RuntimeShape& unextended_output_shape,
1664                                T* output_data) {
1665   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
1666   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
1667   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1668   const RuntimeShape output_shape =
1669       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1670 
1671   NdArrayDesc<4> desc1;
1672   NdArrayDesc<4> desc2;
1673   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
1674                                       unextended_input2_shape, &desc1, &desc2);
1675 
1676   for (int b = 0; b < output_shape.Dims(0); ++b) {
1677     for (int y = 0; y < output_shape.Dims(1); ++y) {
1678       for (int x = 0; x < output_shape.Dims(2); ++x) {
1679         for (int c = 0; c < output_shape.Dims(3); ++c) {
1680           auto out_idx = Offset(output_shape, b, y, x, c);
1681           auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
1682           auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
1683           auto in1_val = input1_data[in1_idx];
1684           auto in2_val = input2_data[in2_idx];
1685           output_data[out_idx] = std::pow(in1_val, in2_val);
1686         }
1687       }
1688     }
1689   }
1690 }
1691 
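// Reverses the input along the given axis by copying copy_size-element blocks
// in reverse order within each outer slice.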
1692 template <typename Scalar>
1693 void Reverse(int axis, const RuntimeShape& input_shape,
1694              const Scalar* input_data, const RuntimeShape& output_shape,
1695              Scalar* output_data) {
1696   ruy::profiler::ScopeLabel label("Reverse");
1697 
1698   int outer_size = 1;
1699   for (int i = 0; i < axis; ++i) {
1700     outer_size *= input_shape.Dims(i);
1701   }
1702 
1703   int copy_size = 1;
1704   for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
1705     copy_size *= input_shape.Dims(i);
1706   }
1707 
1708   const int dims_at_axis = input_shape.Dims(axis);
1709   for (int i = 0; i < outer_size; ++i) {
1710     for (int j = 0; j < dims_at_axis; ++j) {
1711       const int start_pos = (i * dims_at_axis + j) * copy_size;
1712       Scalar* output_ptr = output_data + start_pos;
1713       int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
1714       memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
1715     }
1716   }
1717 }
1718 
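// For each batch entry, reverses the first seq_lengths[batch] elements along
// seq_dim; elements past the sequence length are copied through unchanged.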
1719 template <typename Scalar, typename TS>
1720 void ReverseSequence(const TS* seq_lengths, const int seq_dim,
1721                      const int batch_dim, const RuntimeShape& input_shape,
1722                      const Scalar* input_data, const RuntimeShape& output_shape,
1723                      Scalar* output_data) {
1724   ruy::profiler::ScopeLabel label("ReverseSequence");
1725 
1726   int outer_size = 1;
1727   int outer_dim = std::min(batch_dim, seq_dim);
1728   int medium_dim = std::max(batch_dim, seq_dim);
1729   for (int i = 0; i < outer_dim; ++i) {
1730     outer_size *= input_shape.Dims(i);
1731   }
1732 
1733   int medium_size = 1;
1734   for (int i = outer_dim + 1; i < medium_dim; ++i) {
1735     medium_size *= input_shape.Dims(i);
1736   }
1737 
1738   int copy_size = 1;
1739   for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
1740     copy_size *= input_shape.Dims(i);
1741   }
1742 
1743   const int dims_at_outer_dim = input_shape.Dims(outer_dim);
1744   const int dims_at_medium_dim = input_shape.Dims(medium_dim);
1745 
1746   Scalar* output_ptr;
1747   if (batch_dim > seq_dim) {
1748     for (int i = 0; i < outer_size; ++i) {
1749       for (int j = 0; j < dims_at_outer_dim; ++j) {
1750         const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1751         for (int p = 0; p < medium_size; ++p) {
1752           for (int q = 0; q < dims_at_medium_dim; ++q) {
1753             const int in_pos =
1754                 ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1755             const Scalar* in_ptr = input_data + in_pos;
1756             int sl = seq_lengths[q] - 1;
1757             if (j > sl) {
1758               output_ptr = output_data + in_pos;
1759             } else {
1760               const int out_pos_base =
1761                   (i * dims_at_outer_dim + sl - j) * medium_size;
1762               const int out_pos =
1763                   ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1764               output_ptr = output_data + out_pos;
1765             }
1766             memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1767           }
1768         }
1769       }
1770     }
1771   } else if (batch_dim < seq_dim) {
1772     for (int i = 0; i < outer_size; ++i) {
1773       for (int j = 0; j < dims_at_outer_dim; ++j) {
1774         const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1775         int sl = seq_lengths[j] - 1;
1776         const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1777         for (int p = 0; p < medium_size; ++p) {
1778           for (int q = 0; q < dims_at_medium_dim; ++q) {
1779             const int in_pos =
1780                 ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1781             const Scalar* in_ptr = input_data + in_pos;
1782             if (q > sl) {
1783               output_ptr = output_data + in_pos;
1784             } else {
1785               const int out_pos =
1786                   ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
1787                   copy_size;
1788               output_ptr = output_data + out_pos;
1789             }
1790             memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1791           }
1792         }
1793       }
1794     }
1795   }
1796 }
1797 
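// Sums the input rows that share a segment id into the corresponding output
// row; the output is zero-initialized first.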
1798 template <typename T>
1799 inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
1800                        const RuntimeShape& segment_ids_shape,
1801                        const int32_t* segment_ids_data,
1802                        const RuntimeShape& output_shape, T* output_data) {
1803   const int segment_flat_size =
1804       MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
1805 
1806   memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
1807 
1808   for (int i = 0; i < input_shape.Dims(0); i++) {
1809     int output_index = segment_ids_data[i];
1810     for (int j = 0; j < segment_flat_size; ++j) {
1811       output_data[output_index * segment_flat_size + j] +=
1812           input_data[i * segment_flat_size + j];
1813     }
1814   }
1815 }
1816 
1817 }  // namespace reference_ops
1818 }  // namespace tflite
1819 
1820 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
1821