/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
17 
18 #include <stdint.h>
19 #include <sys/types.h>
20 
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/legacy_types.h"
23 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
24 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27 
28 namespace tflite {
29 
30 namespace reference_ops {
31 
32 static constexpr int kDepthwiseReverseShift = -1;
33 
ShapeFromDims(const tflite::Dims<4> & dims,RuntimeShape * shape)34 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
35   shape->BuildFrom(
36       {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
37 }
38 
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)39 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
40                           const float* filter_data, const Dims<4>& filter_dims,
41                           const float* bias_data, const Dims<4>& bias_dims,
42                           int stride_width, int stride_height,
43                           int dilation_width_factor, int dilation_height_factor,
44                           int pad_width, int pad_height, int depth_multiplier,
45                           float output_activation_min,
46                           float output_activation_max, float* output_data,
47                           const Dims<4>& output_dims) {
48   tflite::DepthwiseParams op_params;
49   // Padding type is ignored, but still set.
50   op_params.padding_type = PaddingType::kSame;
51   op_params.padding_values.width = pad_width;
52   op_params.padding_values.height = pad_height;
53   op_params.stride_width = stride_width;
54   op_params.stride_height = stride_height;
55   op_params.dilation_width_factor = dilation_width_factor;
56   op_params.dilation_height_factor = dilation_height_factor;
57   op_params.depth_multiplier = depth_multiplier;
58   op_params.float_activation_min = output_activation_min;
59   op_params.float_activation_max = output_activation_max;
60 
61   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
62                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
63                 bias_data, DimsToShape(output_dims), output_data);
64 }
65 
// Legacy float DepthwiseConv without dilation arguments; forwards with both
// dilation factors fixed to 1 (i.e. no dilation).
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height,
                /*dilation_width_factor=*/1, /*dilation_height_factor=*/1,
                pad_width, pad_height, depth_multiplier, output_activation_min,
                output_activation_max, output_data, output_dims);
}
79 
// Legacy, for compatibility with old checked-in code. The fused activation is
// a template parameter; the float clamp range is derived from it at runtime.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, float* output_data,
                   const Dims<4>& output_dims) {
  // Translate the compile-time activation kind into a min/max clamp pair.
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, pad_width, pad_height,
                depth_multiplier, output_activation_min, output_activation_max,
                output_data, output_dims);
}
95 
// Legacy, for compatibility with old checked-in code. The single `stride`
// argument is used for both the width and height strides.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}
107 
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)108 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
109                           int32 input_offset, const uint8* filter_data,
110                           const Dims<4>& filter_dims, int32 filter_offset,
111                           const int32* bias_data, const Dims<4>& bias_dims,
112                           int stride_width, int stride_height,
113                           int dilation_width_factor, int dilation_height_factor,
114                           int pad_width, int pad_height, int depth_multiplier,
115                           int32 output_offset, int32 output_multiplier,
116                           int output_shift, int32 output_activation_min,
117                           int32 output_activation_max, uint8* output_data,
118                           const Dims<4>& output_dims) {
119   tflite::DepthwiseParams op_params;
120   // Padding type is ignored, but still set.
121   op_params.padding_type = PaddingType::kSame;
122   op_params.padding_values.width = pad_width;
123   op_params.padding_values.height = pad_height;
124   op_params.stride_width = stride_width;
125   op_params.stride_height = stride_height;
126   op_params.dilation_width_factor = dilation_width_factor;
127   op_params.dilation_height_factor = dilation_height_factor;
128   op_params.depth_multiplier = depth_multiplier;
129   op_params.quantized_activation_min = output_activation_min;
130   op_params.quantized_activation_max = output_activation_max;
131   op_params.input_offset = input_offset;
132   op_params.weights_offset = filter_offset;
133   op_params.output_offset = output_offset;
134   op_params.output_multiplier = output_multiplier;
135   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
136   op_params.output_shift = kDepthwiseReverseShift * output_shift;
137 
138   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
139                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
140                 bias_data, DimsToShape(output_dims), output_data);
141 }
142 
// Legacy quantized DepthwiseConv without dilation arguments; forwards with
// both dilation factors fixed to 1 (i.e. no dilation).
inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, /*dilation_width_factor=*/1,
                /*dilation_height_factor=*/1, pad_width, pad_height,
                depth_multiplier, output_offset, output_multiplier,
                output_shift, output_activation_min, output_activation_max,
                output_data, output_dims);
}
160 
// Legacy, for compatibility with old checked-in code. Quantized variant with
// the fused activation as a template parameter.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  // With no fused activation, the caller-provided clamp range must span the
  // full uint8 range; otherwise clamping would silently alter results.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}
183 
// Legacy, for compatibility with old checked-in code. The single `stride`
// argument is used for both the width and height strides.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}
202 
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)203 inline void Conv(const float* input_data, const Dims<4>& input_dims,
204                  const float* filter_data, const Dims<4>& filter_dims,
205                  const float* bias_data, const Dims<4>& bias_dims,
206                  int stride_width, int stride_height, int dilation_width_factor,
207                  int dilation_height_factor, int pad_width, int pad_height,
208                  float output_activation_min, float output_activation_max,
209                  float* output_data, const Dims<4>& output_dims,
210                  float* im2col_data, const Dims<4>& im2col_dims) {
211   tflite::ConvParams op_params;
212   // Padding type is ignored, but still set.
213   op_params.padding_type = PaddingType::kSame;
214   op_params.padding_values.width = pad_width;
215   op_params.padding_values.height = pad_height;
216   op_params.stride_width = stride_width;
217   op_params.stride_height = stride_height;
218   op_params.dilation_width_factor = dilation_width_factor;
219   op_params.dilation_height_factor = dilation_height_factor;
220   op_params.float_activation_min = output_activation_min;
221   op_params.float_activation_max = output_activation_max;
222 
223   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
224        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
225        output_data, DimsToShape(im2col_dims), im2col_data);
226 }
227 
// Legacy float Conv with the fused activation as a template parameter; the
// float clamp range is derived from it and then forwarded.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
          const float* filter_data, const Dims<4>& filter_dims,
          const float* bias_data, const Dims<4>& bias_dims, int stride_width,
          int stride_height, int dilation_width_factor,
          int dilation_height_factor, int pad_width, int pad_height,
          float* output_data, const Dims<4>& output_dims, float* im2col_data,
          const Dims<4>& im2col_dims) {
  // Translate the compile-time activation kind into a min/max clamp pair.
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
       stride_width, stride_height, dilation_width_factor,
       dilation_height_factor, pad_width, pad_height, output_activation_min,
       output_activation_max, output_data, output_dims, im2col_data,
       im2col_dims);
}
244 
// legacy, for compatibility with old checked-in code
// No dilation arguments: forwards with both dilation factors fixed to 1.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
          const float* filter_data, const Dims<4>& filter_dims,
          const float* bias_data, const Dims<4>& bias_dims, int stride_width,
          int stride_height, int pad_width, int pad_height, float* output_data,
          const Dims<4>& output_dims, float* im2col_data,
          const Dims<4>& im2col_dims) {
  // Translate the compile-time activation kind into a min/max clamp pair.
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
       stride_width, stride_height, /*dilation_width_factor=*/1,
       /*dilation_height_factor=*/1, pad_width, pad_height,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims);
}
260 
// legacy, for compatibility with old checked-in code
// Single `stride` used for both dimensions; dilation factors fixed to 1.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
          const float* filter_data, const Dims<4>& filter_dims,
          const float* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, float* output_data,
          const Dims<4>& output_dims, float* im2col_data,
          const Dims<4>& im2col_dims) {
  Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
           bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
           output_dims, im2col_data, im2col_dims);
}
273 
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemm_context)274 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
275                  int32 input_offset, const uint8* filter_data,
276                  const Dims<4>& filter_dims, int32 filter_offset,
277                  const int32* bias_data, const Dims<4>& bias_dims,
278                  int stride_width, int stride_height, int dilation_width_factor,
279                  int dilation_height_factor, int pad_width, int pad_height,
280                  int32 output_offset, int32 output_multiplier, int output_shift,
281                  int32 output_activation_min, int32 output_activation_max,
282                  uint8* output_data, const Dims<4>& output_dims,
283                  uint8* im2col_data, const Dims<4>& im2col_dims,
284                  gemmlowp::GemmContext* gemm_context) {
285   tflite::ConvParams op_params;
286   // Padding type is ignored, but still set.
287   op_params.padding_type = PaddingType::kSame;
288   op_params.padding_values.width = pad_width;
289   op_params.padding_values.height = pad_height;
290   op_params.stride_width = stride_width;
291   op_params.stride_height = stride_height;
292   op_params.dilation_width_factor = dilation_width_factor;
293   op_params.dilation_height_factor = dilation_height_factor;
294   op_params.input_offset = input_offset;
295   op_params.weights_offset = filter_offset;
296   op_params.output_offset = output_offset;
297   op_params.output_multiplier = output_multiplier;
298   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
299   op_params.output_shift = kReverseShift * output_shift;
300   op_params.quantized_activation_min = output_activation_min;
301   op_params.quantized_activation_max = output_activation_max;
302 
303   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
304        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
305        output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
306 }
307 
// Legacy quantized Conv without dilation arguments; forwards with both
// dilation factors fixed to 1 (i.e. no dilation).
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_offset, const uint8* filter_data,
                 const Dims<4>& filter_dims, int32 filter_offset,
                 const int32* bias_data, const Dims<4>& bias_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int32 output_offset, int32 output_multiplier,
                 int output_shift, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims, uint8* im2col_data,
                 const Dims<4>& im2col_dims,
                 gemmlowp::GemmContext* gemm_context) {
  Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
       filter_offset, bias_data, bias_dims, stride_width, stride_height,
       /*dilation_width_factor=*/1, /*dilation_height_factor=*/1, pad_width,
       pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims, gemm_context);
}
325 
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_offset, const uint8* filter_data,
                 const Dims<4>& filter_dims, int32 filter_offset,
                 const int32* bias_data, const Dims<4>& bias_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int32 output_offset, int32 output_multiplier,
                 int output_shift, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims, uint8* im2col_data,
                 const Dims<4>& im2col_dims,
                 gemmlowp::GemmContext* gemm_context) {
  // Reject instantiation with any unknown fused-activation value.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  // With no fused activation, the caller-provided clamp range must span the
  // full uint8 range; otherwise clamping would silently alter results.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
       filter_offset, bias_data, bias_dims, stride_width, stride_height,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims, gemm_context);
}
354 
// legacy, for compatibility with old checked-in code
// Single `stride` used for both the width and height strides.
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
          int32 input_offset, const uint8* filter_data,
          const Dims<4>& filter_dims, int32 filter_offset,
          const int32* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, int32 output_offset,
          int32 output_multiplier, int output_shift,
          int32 output_activation_min, int32 output_activation_max,
          uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
          const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
  Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
           filter_offset, bias_data, bias_dims, stride, stride, pad_width,
           pad_height, output_offset, output_multiplier, output_shift,
           output_activation_min, output_activation_max, output_data,
           output_dims, im2col_data, im2col_dims, gemm_context);
}
372 
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)373 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
374                           const float* filter_data, const Dims<4>& filter_dims,
375                           int stride_width, int stride_height, int pad_width,
376                           int pad_height, float* output_data,
377                           const Dims<4>& output_dims, float* im2col_data,
378                           const Dims<4>& im2col_dims) {
379   tflite::ConvParams op_params;
380   // Padding type is ignored, but still set.
381   op_params.padding_type = PaddingType::kSame;
382   op_params.padding_values.width = pad_width;
383   op_params.padding_values.height = pad_height;
384   op_params.stride_width = stride_width;
385   op_params.stride_height = stride_height;
386 
387   TransposeConv(op_params, DimsToShape(input_dims), input_data,
388                 DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
389                 output_data, DimsToShape(im2col_dims), im2col_data);
390 }
391 
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)392 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
393                            const float* weights_data,
394                            const Dims<4>& weights_dims, const float* bias_data,
395                            const Dims<4>& bias_dims,
396                            float output_activation_min,
397                            float output_activation_max, float* output_data,
398                            const Dims<4>& output_dims) {
399   tflite::FullyConnectedParams op_params;
400   op_params.float_activation_min = output_activation_min;
401   op_params.float_activation_max = output_activation_max;
402 
403   FullyConnected(op_params, DimsToShape(input_dims), input_data,
404                  DimsToShape(weights_dims), weights_data,
405                  DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
406                  output_data);
407 }
408 
// legacy, for compatibility with old checked-in code
// Derives the float clamp range from the compile-time fused activation.
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                    const float* weights_data, const Dims<4>& weights_dims,
                    const float* bias_data, const Dims<4>& bias_dims,
                    float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
                 bias_dims, output_activation_min, output_activation_max,
                 output_data, output_dims);
}
421 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemm_context)422 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
423                            int32 input_offset, const uint8* filter_data,
424                            const Dims<4>& filter_dims, int32 filter_offset,
425                            const int32* bias_data, const Dims<4>& bias_dims,
426                            int32 output_offset, int32 output_multiplier,
427                            int output_shift, int32 output_activation_min,
428                            int32 output_activation_max, uint8* output_data,
429                            const Dims<4>& output_dims,
430                            gemmlowp::GemmContext* gemm_context) {
431   tflite::FullyConnectedParams op_params;
432   op_params.input_offset = input_offset;
433   op_params.weights_offset = filter_offset;
434   op_params.output_offset = output_offset;
435   op_params.output_multiplier = output_multiplier;
436   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
437   op_params.output_shift = kReverseShift * output_shift;
438   op_params.quantized_activation_min = output_activation_min;
439   op_params.quantized_activation_max = output_activation_max;
440 
441   FullyConnected(op_params, DimsToShape(input_dims), input_data,
442                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
443                  bias_data, DimsToShape(output_dims), output_data,
444                  gemm_context);
445 }
446 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemm_context)447 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
448                            int32 input_offset, const uint8* filter_data,
449                            const Dims<4>& filter_dims, int32 filter_offset,
450                            const int32* bias_data, const Dims<4>& bias_dims,
451                            int32 output_offset, int32 output_multiplier,
452                            int output_shift, int32 output_activation_min,
453                            int32 output_activation_max, int16* output_data,
454                            const Dims<4>& output_dims,
455                            gemmlowp::GemmContext* gemm_context) {
456   tflite::FullyConnectedParams op_params;
457   op_params.input_offset = input_offset;
458   op_params.weights_offset = filter_offset;
459   op_params.output_offset = output_offset;
460   op_params.output_multiplier = output_multiplier;
461   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
462   op_params.output_shift = kReverseShift * output_shift;
463   op_params.quantized_activation_min = output_activation_min;
464   op_params.quantized_activation_max = output_activation_max;
465 
466   FullyConnected(op_params, DimsToShape(input_dims), input_data,
467                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
468                  bias_data, DimsToShape(output_dims), output_data,
469                  gemm_context);
470 }
471 
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemm_context)472 inline void ShuffledFullyConnected(
473     const uint8* input_data, const Dims<4>& input_dims,
474     const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
475     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
476     int output_shift, int32 output_activation_min, int32 output_activation_max,
477     int16* output_data, const Dims<4>& output_dims,
478     uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
479   tflite::FullyConnectedParams op_params;
480   op_params.output_multiplier = output_multiplier;
481   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
482   op_params.output_shift = kReverseShift * output_shift;
483   op_params.quantized_activation_min = output_activation_min;
484   op_params.quantized_activation_max = output_activation_max;
485 
486   ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
487                          DimsToShape(weights_dims), shuffled_weights_data,
488                          DimsToShape(bias_dims), bias_data,
489                          DimsToShape(output_dims), output_data,
490                          shuffled_input_workspace_data, gemm_context);
491 }
492 
493 // legacy, for compatibility with old checked-in code
494 template <FusedActivationFunctionType Ac>
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemm_context)495 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
496                     int32 input_offset, const uint8* filter_data,
497                     const Dims<4>& filter_dims, int32 filter_offset,
498                     const int32* bias_data, const Dims<4>& bias_dims,
499                     int32 output_offset, int32 output_multiplier,
500                     int output_shift, int32 output_activation_min,
501                     int32 output_activation_max, uint8* output_data,
502                     const Dims<4>& output_dims,
503                     gemmlowp::GemmContext* gemm_context) {
504   static_assert(Ac == FusedActivationFunctionType::kNone ||
505                     Ac == FusedActivationFunctionType::kRelu ||
506                     Ac == FusedActivationFunctionType::kRelu6 ||
507                     Ac == FusedActivationFunctionType::kRelu1,
508                 "");
509   if (Ac == FusedActivationFunctionType::kNone) {
510     TFLITE_DCHECK_EQ(output_activation_min, 0);
511     TFLITE_DCHECK_EQ(output_activation_max, 255);
512   }
513   FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
514                  filter_offset, bias_data, bias_dims, output_offset,
515                  output_multiplier, output_shift, output_activation_min,
516                  output_activation_max, output_data, output_dims, gemm_context);
517 }
518 
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)519 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
520                      const float* prev_activ_data,
521                      const Dims<4>& prev_activ_dims, const float* weights_data,
522                      const Dims<4>& weights_dims, const float* bias_data,
523                      const Dims<4>& bias_dims, const float* prev_state_data,
524                      const Dims<4>& prev_state_dims, float* output_state_data,
525                      const Dims<4>& output_state_dims, float* output_activ_data,
526                      const Dims<4>& output_activ_dims, float* concat_temp_data,
527                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
528                      const Dims<4>& activ_temp_dims) {
529   tflite::LstmCellParams op_params;
530   // Float LSTM cell does not need parameters to be set: leave untouched.
531 
532   LstmCell(op_params, DimsToShape(input_dims), input_data,
533            DimsToShape(prev_activ_dims), prev_activ_data,
534            DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
535            bias_data, DimsToShape(prev_state_dims), prev_state_data,
536            DimsToShape(output_state_dims), output_state_data,
537            DimsToShape(output_activ_dims), output_activ_data,
538            DimsToShape(concat_temp_dims), concat_temp_data,
539            DimsToShape(activ_temp_dims), activ_temp_data);
540 }
541 
542 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemm_context)543 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
544               const uint8* prev_activ_data_uint8,
545               const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
546               const Dims<4>& weights_dims, const int32* bias_data_int32,
547               const Dims<4>& bias_dims, const int16* prev_state_data_int16,
548               const Dims<4>& prev_state_dims, int16* output_state_data_int16,
549               const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
550               const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
551               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
552               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
553               int32 accum_multiplier, int accum_shift,
554               gemmlowp::GemmContext* gemm_context) {
555   tflite::LstmCellParams op_params;
556   op_params.weights_zero_point = weights_zero_point;
557   op_params.accum_multiplier = accum_multiplier;
558   op_params.accum_shift = accum_shift;
559 
560   LstmCell<StateIntegerBits>(
561       op_params, DimsToShape(input_dims), input_data_uint8,
562       DimsToShape(prev_activ_dims), prev_activ_data_uint8,
563       DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
564       bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
565       DimsToShape(output_state_dims), output_state_data_int16,
566       DimsToShape(output_activ_dims), output_activ_data_uint8,
567       DimsToShape(concat_temp_dims), concat_temp_data_uint8,
568       DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
569 }
570 
571 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)572 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
573                   const T* input2_data, const Dims<4>& input2_dims,
574                   T output_activation_min, T output_activation_max,
575                   T* output_data, const Dims<4>& output_dims) {
576   tflite::ArithmeticParams op_params;
577   SetActivationParams(output_activation_min, output_activation_max, &op_params);
578 
579   BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
580                      DimsToShape(input2_dims), input2_data,
581                      DimsToShape(output_dims), output_data);
582 }
583 
584 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)585 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
586                 const T* input2_data, const Dims<4>& input2_dims,
587                 T output_activation_min, T output_activation_max,
588                 T* output_data, const Dims<4>& output_dims) {
589   tflite::ArithmeticParams op_params;
590   SetActivationParams(output_activation_min, output_activation_max, &op_params);
591 
592   Div(op_params, DimsToShape(input1_dims), input1_data,
593       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
594       output_data);
595 }
596 
597 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)598 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
599                           const Dims<4>* const* input_dims, int inputs_count,
600                           Scalar* output_data, const Dims<4>& output_dims) {
601   // For now we don't have a model with a Concatenation with fused activation.
602   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
603 
604   std::vector<RuntimeShape> input_shapes(inputs_count);
605   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
606   for (int i = 0; i < inputs_count; ++i) {
607     ShapeFromDims(*input_dims[i], &input_shapes[i]);
608     input_shapes_indirect[i] = &input_shapes[i];
609   }
610   tflite::ConcatenationParams op_params;
611   op_params.axis = 3 - concat_dim;
612   op_params.inputs_count = inputs_count;
613 
614   Concatenation(op_params, input_shapes_indirect.data(), input_data,
615                 DimsToShape(output_dims), output_data);
616 }
617 
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)618 inline void Concatenation(int concat_dim, const uint8* const* input_data,
619                           const Dims<4>* const* input_dims,
620                           const int32* input_zeropoint,
621                           const float* input_scale, int inputs_count,
622                           uint8* output_data, const Dims<4>& output_dims,
623                           const int32 output_zeropoint,
624                           const float output_scale) {
625   std::vector<RuntimeShape> input_shapes(inputs_count);
626   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
627   for (int i = 0; i < inputs_count; ++i) {
628     ShapeFromDims(*input_dims[i], &input_shapes[i]);
629     input_shapes_indirect[i] = &input_shapes[i];
630   }
631   tflite::ConcatenationParams op_params;
632   op_params.axis = 3 - concat_dim;
633   op_params.input_zeropoint = input_zeropoint;
634   op_params.input_scale = input_scale;
635   op_params.inputs_count = inputs_count;
636   op_params.output_zeropoint = output_zeropoint;
637   op_params.output_scale = output_scale;
638 
639   ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
640                            DimsToShape(output_dims), output_data);
641 }
642 
643 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)644 void DepthConcatenation(const Scalar* const* input_data,
645                         const Dims<4>* const* input_dims, int inputs_count,
646                         Scalar* output_data, const Dims<4>& output_dims) {
647   // For now we don't have a model with a Concatenation with fused activation.
648   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
649 
650   std::vector<RuntimeShape> input_shapes(inputs_count);
651   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
652   for (int i = 0; i < inputs_count; ++i) {
653     ShapeFromDims(*input_dims[i], &input_shapes[i]);
654     input_shapes_indirect[i] = &input_shapes[i];
655   }
656   tflite::ConcatenationParams op_params;
657   op_params.inputs_count = inputs_count;
658 
659   DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
660                      DimsToShape(output_dims), output_data);
661 }
662 
663 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)664 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
665                      int axis, int outputs_count, Scalar* const* output_data,
666                      const Dims<4>* const* output_dims) {
667   std::vector<RuntimeShape> output_shapes(outputs_count);
668   std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
669   for (int i = 0; i < outputs_count; ++i) {
670     ShapeFromDims(*output_dims[i], &output_shapes[i]);
671     output_shapes_indirect[i] = &output_shapes[i];
672   }
673   tflite::SplitParams op_params;
674   op_params.axis = 3 - axis;
675   op_params.num_split = outputs_count;
676 
677   Split(op_params, DimsToShape(input_dims), input_data,
678         output_shapes_indirect.data(), output_data);
679 }
680 
681 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)682 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
683                      int outputs_count, Scalar* const* output_data,
684                      const Dims<4>* const* output_dims) {
685   TFLITE_DCHECK_GE(outputs_count, 1);
686   for (int i = 0; i < outputs_count; i++) {
687     /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
688     /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
689     /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
690   }
691   // For now we don't have a model with a Split with fused activation.
692   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
693 
694   TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
695                   output_data, output_dims);
696 }
697 
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)698 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
699                     float beta, float* output_data,
700                     const RuntimeShape& output_shape) {
701   SoftmaxParams params;
702   params.beta = beta;
703   Softmax(params, input_shape, input_data, output_shape, output_data);
704 }
705 
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)706 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
707                     int32 input_beta_multiplier, int32 input_beta_left_shift,
708                     int diff_min, uint8* output_data,
709                     const RuntimeShape& output_shape) {
710   SoftmaxParams params;
711   params.input_multiplier = input_beta_multiplier;
712   params.input_left_shift = input_beta_left_shift;
713   params.diff_min = diff_min;
714   Softmax(params, input_shape, input_data, output_shape, output_data);
715 }
716 
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)717 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
718                        float* output_data, const RuntimeShape& output_shape) {
719   SoftmaxParams params;
720   // No params currently used for float LogSoftmax.
721   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
722 }
723 
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)724 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
725                        int32 input_multiplier, int32 input_left_shift,
726                        int32 reverse_scaling_divisor,
727                        int32 reverse_scaling_right_shift, int diff_min,
728                        uint8* output_data, const RuntimeShape& output_shape) {
729   SoftmaxParams params;
730   params.input_multiplier = input_multiplier;
731   params.input_left_shift = input_left_shift;
732   params.reverse_scaling_divisor = reverse_scaling_divisor;
733   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
734   params.diff_min = diff_min;
735   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
736 }
737 
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)738 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
739                      int32 input_zero_point, int32 input_range_radius,
740                      int32 input_multiplier, int input_left_shift,
741                      uint8* output_data, const RuntimeShape& output_shape) {
742   LogisticParams params;
743   params.input_zero_point = input_zero_point;
744   params.input_range_radius = input_range_radius;
745   params.input_multiplier = input_multiplier;
746   params.input_left_shift = input_left_shift;
747   Logistic(params, input_shape, input_data, output_shape, output_data);
748 }
749 
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)750 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
751                      const RuntimeShape& output_shape, int16* output_data) {
752   LogisticParams params;
753   // No params currently needed by int16 Logistic.
754   Logistic(params, input_shape, input_data, output_shape, output_data);
755 }
756 
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)757 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
758                  int32 input_zero_point, int32 input_range_radius,
759                  int32 input_multiplier, int input_left_shift,
760                  uint8* output_data, const RuntimeShape& output_shape) {
761   TanhParams params;
762   params.input_zero_point = input_zero_point;
763   params.input_range_radius = input_range_radius;
764   params.input_multiplier = input_multiplier;
765   params.input_left_shift = input_left_shift;
766   Tanh(params, input_shape, input_data, output_shape, output_data);
767 }
768 
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)769 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
770                  int input_left_shift, int16* output_data,
771                  const RuntimeShape& output_shape) {
772   TanhParams params;
773   params.input_left_shift = input_left_shift;
774   Tanh(params, input_shape, input_data, output_shape, output_data);
775 }
776 
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)777 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
778                        int32 zero_point, double scale, float* output_data,
779                        const Dims<4>& output_dims) {
780   tflite::DequantizationParams op_params;
781   op_params.zero_point = zero_point;
782   op_params.scale = scale;
783 
784   Dequantize(op_params, DimsToShape(input_dims), input_data,
785              DimsToShape(output_dims), output_data);
786 }
787 
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)788 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
789                       float rmin, float rmax, int num_bits, float* output_data,
790                       const Dims<4>& output_dims) {
791   tflite::FakeQuantParams op_params;
792   op_params.num_bits = num_bits;
793   op_params.minmax.min = rmin;
794   op_params.minmax.max = rmax;
795 
796   FakeQuant(op_params, DimsToShape(input_dims), input_data,
797             DimsToShape(output_dims), output_data);
798 }
799 
800 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)801 inline void Gather(const T* input_data, const Dims<4>& input_dims,
802                    int input_rank, const int32* coords_data,
803                    const Dims<4>& coords_dims, T* output_data,
804                    const Dims<4>& output_dims) {
805   tflite::GatherParams op_params;
806   op_params.axis = 4 - input_rank;
807 
808   Gather(op_params, DimsToShape(input_dims), input_data,
809          DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
810          output_data);
811 }
812 
LegacyReverseBits32(uint32 n)813 inline uint32 LegacyReverseBits32(uint32 n) {
814   n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
815   n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
816   n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
817   return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
818           ((n & 0xFF000000) >> 24));
819 }
820 
StridedSliceReverseIndices(tflite::StridedSliceParams * p)821 inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
822   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
823   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
824 
825   std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
826   std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
827   std::reverse(p->strides, p->strides + p->strides_count);
828 
829   p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
830                   (32 - p->start_indices_count);
831   p->ellipsis_mask =
832       LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
833       (32 - p->start_indices_count);
834   p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
835                 (32 - p->start_indices_count);
836   p->new_axis_mask =
837       LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
838       (32 - p->start_indices_count);
839   p->shrink_axis_mask =
840       LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
841       (32 - p->start_indices_count);
842 }
843 
844 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)845 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
846                          int begin_mask, int end_mask, int shrink_axis_mask,
847                          const std::vector<int>& start_indices,
848                          const std::vector<int>& stop_indices,
849                          const std::vector<int>& strides, T* output_data,
850                          const Dims<4>& output_dims) {
851   TFLITE_DCHECK_EQ(start_indices.size(), 4);
852   auto op_params = strided_slice::BuildStridedSliceParams(
853       begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
854       strides);
855   StridedSliceReverseIndices(&op_params);
856 
857   StridedSlice(op_params, DimsToShape(input_dims), input_data,
858                DimsToShape(output_dims), output_data);
859 }
860 
861 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)862 inline void Mean(const T* input_data, const Dims<4>& input_dims,
863                  const std::vector<int>& reduction_indices, T* output_data,
864                  const Dims<4>& output_dims) {
865   tflite::MeanParams op_params;
866   op_params.axis_count = reduction_indices.size();
867   for (int i = 0; i < op_params.axis_count; ++i) {
868     op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
869   }
870 
871   Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
872        output_data);
873 }
874 
875 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)876 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
877                const Dims<4>& output_dims, const int* permuted_axes) {
878   TransposeParams params;
879   params.perm_count = 4;
880   for (int i = 0; i < 4; ++i) {
881     params.perm[i] = 3 - permuted_axes[3 - i];
882   }
883   Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
884             output);
885 }
886 
887 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)888 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
889                        const T* input2_data, const Dims<4>& input2_dims,
890                        bool* output_data, const Dims<4>& output_dims) {
891   ComparisonParams op_params;
892   // No parameters needed.
893   ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
894                        DimsToShape(input2_dims), input2_data,
895                        DimsToShape(output_dims), output_data);
896 }
897 
898 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)899 inline void Comparison(int left_shift, const T* input1_data,
900                        const Dims<4>& input1_dims, int32 input1_offset,
901                        int32 input1_multiplier, int input1_shift,
902                        const T* input2_data, const Dims<4>& input2_dims,
903                        int32 input2_offset, int32 input2_multiplier,
904                        int input2_shift, bool* output_data,
905                        const Dims<4>& output_dims) {
906   tflite::ComparisonParams op_params;
907   op_params.left_shift = left_shift;
908   op_params.input1_offset = input1_offset;
909   op_params.input1_multiplier = input1_multiplier;
910   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
911   op_params.input1_shift = kReverseShift * input1_shift;
912   op_params.input2_offset = input2_offset;
913   op_params.input2_multiplier = input2_multiplier;
914   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
915   op_params.input2_shift = kReverseShift * input2_shift;
916 
917   ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
918                               DimsToShape(input2_dims), input2_data,
919                               DimsToShape(output_dims), output_data);
920 }
921 
922 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)923 inline void BroadcastComparison(const T* input1_data,
924                                 const Dims<4>& input1_dims,
925                                 const T* input2_data,
926                                 const Dims<4>& input2_dims, bool* output_data,
927                                 const Dims<4>& output_dims) {
928   ComparisonParams op_params;
929   // No parameters needed.
930   BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
931                                       input1_data, DimsToShape(input2_dims),
932                                       input2_data, DimsToShape(output_dims),
933                                       output_data);
934 }
935 
936 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)937 inline void BroadcastComparison(int left_shift, const T* input1_data,
938                                 const Dims<4>& input1_dims, int32 input1_offset,
939                                 int32 input1_multiplier, int input1_shift,
940                                 const T* input2_data,
941                                 const Dims<4>& input2_dims, int32 input2_offset,
942                                 int32 input2_multiplier, int input2_shift,
943                                 bool* output_data, const Dims<4>& output_dims) {
944   ComparisonParams op_params;
945 
946   op_params.left_shift = left_shift;
947   op_params.input1_offset = input1_offset;
948   op_params.input1_multiplier = input1_multiplier;
949   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
950   op_params.input1_shift = kReverseShift * input1_shift;
951   op_params.input2_offset = input2_offset;
952   op_params.input2_multiplier = input2_multiplier;
953   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
954   op_params.input2_shift = kReverseShift * input2_shift;
955 
956   BroadcastComparison4DSlowWithScaling<T, F>(
957       op_params, DimsToShape(input1_dims), input1_data,
958       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
959       output_data);
960 }
961 
// Emits the four legacy entry points for one comparison operator `name`:
// plain and quantized ("/8bit"), each with and without broadcast. Every
// entry point just attaches a profiling label and forwards to the
// Comparison/BroadcastComparison helpers above; `name##Fn` names the
// per-element comparison functor declared elsewhere.
#define TFLITE_LEGACY_COMPARISON_OP(name)                                     \
  template <typename T>                                                       \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,          \
                   const T* input2_data, const Dims<4>& input2_dims,          \
                   bool* output_data, const Dims<4>& output_dims) {           \
    gemmlowp::ScopedProfilingLabel label(#name);                              \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,            \
                            input2_dims, output_data, output_dims);           \
  }                                                                           \
  template <typename T>                                                       \
  inline void name(                                                           \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label(#name "/8bit");                      \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,             \
                            input1_offset, input1_multiplier, input1_shift,   \
                            input2_data, input2_dims, input2_offset,          \
                            input2_multiplier, input2_shift, output_data,     \
                            output_dims);                                     \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
      const Dims<4>& input2_dims, bool* output_data,                          \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label("Broadcast" #name);                  \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,   \
                                     input2_dims, output_data, output_dims);  \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit");          \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,    \
                                     input1_offset, input1_multiplier,        \
                                     input1_shift, input2_data, input2_dims,  \
                                     input2_offset, input2_multiplier,        \
                                     input2_shift, output_data, output_dims); \
  }
// Instantiate the legacy entry points for each supported comparison.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1015 
1016 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1017 inline void Select(const D* input_condition_data,
1018                    const Dims<4>& input_condition_dims, const T* input_x_data,
1019                    const Dims<4>& input_x_dims, const T* input_y_data,
1020                    const Dims<4>& input_y_dims, T* output_data,
1021                    const Dims<4>& output_dims) {
1022   Select(DimsToShape(input_condition_dims), input_condition_data,
1023          DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1024          input_y_data, DimsToShape(output_dims), output_data);
1025 }
1026 
1027 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1028 inline void RankOneSelect(const D* input_condition_data,
1029                           const Dims<4>& input_condition_dims,
1030                           const T* input_x_data, const Dims<4>& input_x_dims,
1031                           const T* input_y_data, const Dims<4>& input_y_dims,
1032                           T* output_data, const Dims<4>& output_dims) {
1033   RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1034                 DimsToShape(input_x_dims), input_x_data,
1035                 DimsToShape(input_y_dims), input_y_data,
1036                 DimsToShape(output_dims), output_data);
1037 }
1038 
1039 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1040 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1041                           const T* values, T default_value, T* output_data,
1042                           const Dims<4>& output_dims, bool value_is_scalar) {
1043   SparseToDense(indices, values, default_value, value_is_scalar,
1044                 DimsToShape(output_dims), output_data);
1045 }
1046 
1047 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1048 void Pack(int dim, const Scalar* const* input_data,
1049           const Dims<4>* const* input_dims, int inputs_count,
1050           Scalar* output_data, const Dims<4>& output_dims) {
1051   std::vector<RuntimeShape> input_shapes(inputs_count);
1052   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1053   for (int i = 0; i < inputs_count; ++i) {
1054     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1055     input_shapes_indirect[i] = &input_shapes[i];
1056   }
1057   tflite::PackParams op_params;
1058   op_params.axis = 3 - dim;
1059   op_params.inputs_count = inputs_count;
1060 
1061   Pack(op_params, input_shapes_indirect.data(), input_data,
1062        DimsToShape(output_dims), output_data);
1063 }
1064 
1065 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1066 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1067             int dimensions, int outputs_count, Scalar* const* output_datas,
1068             const Dims<4>& output_dims) {
1069   tflite::UnpackParams op_params;
1070   op_params.axis = 3 - axis;
1071   op_params.num_split = outputs_count;
1072 
1073   Unpack(op_params, DimsToShape(input_dims), input_data,
1074          DimsToShape(output_dims), output_datas);
1075 }
1076 
1077 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1078 void Pack(int dim, const Scalar* const* input_data,
1079           const Dims<4>* const* input_dims, const int32* input_zeropoint,
1080           const float* input_scale, int inputs_count, Scalar* output_data,
1081           const Dims<4>& output_dims, const int32 output_zeropoint,
1082           const float output_scale) {
1083   std::vector<RuntimeShape> input_shapes(inputs_count);
1084   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1085   for (int i = 0; i < inputs_count; ++i) {
1086     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1087     input_shapes_indirect[i] = &input_shapes[i];
1088   }
1089   tflite::PackParams op_params;
1090   op_params.axis = 3 - dim;
1091   op_params.input_zeropoint = input_zeropoint;
1092   op_params.input_scale = input_scale;
1093   op_params.inputs_count = inputs_count;
1094   op_params.output_zeropoint = output_zeropoint;
1095   op_params.output_scale = output_scale;
1096 
1097   PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1098                   DimsToShape(output_dims), output_data);
1099 }
1100 
1101 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1102 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1103                      float* output_data, const RuntimeShape& output_shape) {
1104   static_assert(Ac == FusedActivationFunctionType::kNone, "");
1105   tflite::L2NormalizationParams op_params;
1106   // No params need to be set for float.
1107 
1108   L2Normalization(op_params, input_shape, input_data, output_shape,
1109                   output_data);
1110 }
1111 
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1112 inline void L2Normalization(const uint8* input_data,
1113                             const RuntimeShape& input_shape,
1114                             int32 input_zero_point, uint8* output_data,
1115                             const RuntimeShape& output_shape) {
1116   tflite::L2NormalizationParams op_params;
1117   op_params.input_zero_point = input_zero_point;
1118 
1119   L2Normalization(op_params, input_shape, input_data, output_shape,
1120                   output_data);
1121 }
1122 
1123 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1124 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1125                      float* output_data, const Dims<4>& output_dims) {
1126   L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1127                       DimsToShape(output_dims));
1128 }
1129 
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1130 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1131                             int32 input_zero_point, uint8* output_data,
1132                             const Dims<4>& output_dims) {
1133   L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1134                   output_data, DimsToShape(output_dims));
1135 }
1136 
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1137 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1138                  float* output_data, const Dims<4>& output_dims) {
1139   Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1140        output_data);
1141 }
1142 
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1143 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1144                   float* output_data, const Dims<4>& output_dims) {
1145   Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1146         output_data);
1147 }
1148 
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1149 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1150                   float* output_data, const Dims<4>& output_dims) {
1151   Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1152         output_data);
1153 }
1154 
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1155 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1156                   const RuntimeShape& input_shape, uint8* output_data,
1157                   const RuntimeShape& output_shape) {
1158   tflite::ActivationParams params;
1159   params.quantized_activation_max = max_value;
1160   params.quantized_activation_min = min_value;
1161   ReluX(params, input_shape, input_data, output_shape, output_data);
1162 }
1163 
1164 template <FusedActivationFunctionType Ac>
Add(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1165 inline void Add(int left_shift, const uint8* input1_data,
1166                 const Dims<4>& input1_dims, int32 input1_offset,
1167                 int32 input1_multiplier, int input1_shift,
1168                 const uint8* input2_data, const Dims<4>& input2_dims,
1169                 int32 input2_offset, int32 input2_multiplier, int input2_shift,
1170                 int32 output_offset, int32 output_multiplier, int output_shift,
1171                 int32 output_activation_min, int32 output_activation_max,
1172                 uint8* output_data, const Dims<4>& output_dims) {
1173   constexpr int kReverseShift = -1;
1174   static_assert(Ac == FusedActivationFunctionType::kNone ||
1175                     Ac == FusedActivationFunctionType::kRelu ||
1176                     Ac == FusedActivationFunctionType::kRelu6 ||
1177                     Ac == FusedActivationFunctionType::kRelu1,
1178                 "");
1179   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1180   if (Ac == FusedActivationFunctionType::kNone) {
1181     TFLITE_DCHECK_EQ(output_activation_min, 0);
1182     TFLITE_DCHECK_EQ(output_activation_max, 255);
1183   }
1184 
1185   tflite::ArithmeticParams op_params;
1186   op_params.left_shift = left_shift;
1187   op_params.input1_offset = input1_offset;
1188   op_params.input1_multiplier = input1_multiplier;
1189   op_params.input1_shift = kReverseShift * input1_shift;
1190   op_params.input2_offset = input2_offset;
1191   op_params.input2_multiplier = input2_multiplier;
1192   op_params.input2_shift = kReverseShift * input2_shift;
1193   op_params.output_offset = output_offset;
1194   op_params.output_multiplier = output_multiplier;
1195   op_params.output_shift = kReverseShift * output_shift;
1196   op_params.quantized_activation_min = output_activation_min;
1197   op_params.quantized_activation_max = output_activation_max;
1198   Add(op_params, DimsToShape(input1_dims), input1_data,
1199       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1200       output_data);
1201 }
1202 
1203 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1204 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1205          const int32* input2_data, const Dims<4>& input2_dims,
1206          int32* output_data, const Dims<4>& output_dims) {
1207   gemmlowp::ScopedProfilingLabel label("Add/int32");
1208   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1209 
1210   tflite::ArithmeticParams op_params;
1211   op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1212   op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1213   Add(op_params, DimsToShape(input1_dims), input1_data,
1214       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1215       output_data);
1216 }
1217 
1218 template <FusedActivationFunctionType Ac>
BroadcastAdd(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1219 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
1220                          const Dims<4>& input1_dims, int32 input1_offset,
1221                          int32 input1_multiplier, int input1_shift,
1222                          const uint8* input2_data, const Dims<4>& input2_dims,
1223                          int32 input2_offset, int32 input2_multiplier,
1224                          int input2_shift, int32 output_offset,
1225                          int32 output_multiplier, int output_shift,
1226                          int32 output_activation_min,
1227                          int32 output_activation_max, uint8* output_data,
1228                          const Dims<4>& output_dims) {
1229   constexpr int kReverseShift = -1;
1230   static_assert(Ac == FusedActivationFunctionType::kNone ||
1231                     Ac == FusedActivationFunctionType::kRelu ||
1232                     Ac == FusedActivationFunctionType::kRelu6 ||
1233                     Ac == FusedActivationFunctionType::kRelu1,
1234                 "");
1235   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1236   if (Ac == FusedActivationFunctionType::kNone) {
1237     TFLITE_DCHECK_EQ(output_activation_min, 0);
1238     TFLITE_DCHECK_EQ(output_activation_max, 255);
1239   }
1240 
1241   tflite::ArithmeticParams op_params;
1242   op_params.left_shift = left_shift;
1243   op_params.input1_offset = input1_offset;
1244   op_params.input1_multiplier = input1_multiplier;
1245   op_params.input1_shift = kReverseShift * input1_shift;
1246   op_params.input2_offset = input2_offset;
1247   op_params.input2_multiplier = input2_multiplier;
1248   op_params.input2_shift = kReverseShift * input2_shift;
1249   op_params.output_offset = output_offset;
1250   op_params.output_multiplier = output_multiplier;
1251   op_params.output_shift = kReverseShift * output_shift;
1252   op_params.quantized_activation_min = output_activation_min;
1253   op_params.quantized_activation_max = output_activation_max;
1254   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1255                      DimsToShape(input2_dims), input2_data,
1256                      DimsToShape(output_dims), output_data);
1257 }
1258 
1259 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1260 void Add(const float* input1_data, const Dims<4>& input1_dims,
1261          const float* input2_data, const Dims<4>& input2_dims,
1262          float* output_data, const Dims<4>& output_dims) {
1263   float output_activation_min, output_activation_max;
1264   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1265 
1266   tflite::ArithmeticParams op_params;
1267   op_params.float_activation_min = output_activation_min;
1268   op_params.float_activation_max = output_activation_max;
1269   Add(op_params, DimsToShape(input1_dims), input1_data,
1270       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1271       output_data);
1272 }
1273 
1274 template <typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1275 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1276                   const T* input2_data, const Dims<4>& input2_dims,
1277                   T output_activation_min, T output_activation_max,
1278                   T* output_data, const Dims<4>& output_dims) {
1279   tflite::ArithmeticParams op_params;
1280   op_params.float_activation_min = output_activation_min;
1281   op_params.float_activation_max = output_activation_max;
1282   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1283                      DimsToShape(input2_dims), input2_data,
1284                      DimsToShape(output_dims), output_data);
1285 }
1286 
1287 template <FusedActivationFunctionType Ac>
BroadcastAddFivefold(int y0,int y1,int y2,int y3,int y4,int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1288 inline void BroadcastAddFivefold(
1289     int y0, int y1, int y2, int y3, int y4, int left_shift,
1290     const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
1291     int32 input1_multiplier, int input1_shift, const uint8* input2_data,
1292     const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
1293     int input2_shift, int32 output_offset, int32 output_multiplier,
1294     int output_shift, int32 output_activation_min, int32 output_activation_max,
1295     uint8* output_data, const Dims<4>& output_dims) {
1296   constexpr int kReverseShift = -1;
1297   static_assert(Ac == FusedActivationFunctionType::kNone ||
1298                     Ac == FusedActivationFunctionType::kRelu ||
1299                     Ac == FusedActivationFunctionType::kRelu6 ||
1300                     Ac == FusedActivationFunctionType::kRelu1,
1301                 "");
1302   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1303   if (Ac == FusedActivationFunctionType::kNone) {
1304     TFLITE_DCHECK_EQ(output_activation_min, 0);
1305     TFLITE_DCHECK_EQ(output_activation_max, 255);
1306   }
1307   tflite::ArithmeticParams op_params;
1308   op_params.broadcast_category =
1309       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
1310   op_params.left_shift = left_shift;
1311   op_params.input1_offset = input1_offset;
1312   op_params.input1_multiplier = input1_multiplier;
1313   op_params.input1_shift = kReverseShift * input1_shift;
1314   op_params.input2_offset = input2_offset;
1315   op_params.input2_multiplier = input2_multiplier;
1316   op_params.input2_shift = kReverseShift * input2_shift;
1317   op_params.output_offset = output_offset;
1318   op_params.output_multiplier = output_multiplier;
1319   op_params.output_shift = kReverseShift * output_shift;
1320   op_params.quantized_activation_min = output_activation_min;
1321   op_params.quantized_activation_max = output_activation_max;
1322   op_params.broadcast_shape[4] = y0;
1323   op_params.broadcast_shape[3] = y1;
1324   op_params.broadcast_shape[2] = y2;
1325   op_params.broadcast_shape[1] = y3;
1326   op_params.broadcast_shape[0] = y4;
1327   BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
1328                        DimsToShape(input2_dims), input2_data,
1329                        DimsToShape(output_dims), output_data);
1330 }
1331 
1332 // legacy, for compatibility with old checked-in code
1333 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1334 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1335                   const T* input2_data, const Dims<4>& input2_dims,
1336                   T* output_data, const Dims<4>& output_dims) {
1337   T output_activation_min, output_activation_max;
1338   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1339 
1340   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1341                output_activation_min, output_activation_max, output_data,
1342                output_dims);
1343 }
1344 
1345 template <FusedActivationFunctionType Ac>
Add(const int16 * input1_data,const Dims<4> & input1_dims,int input1_shift,const int16 * input2_data,const Dims<4> & input2_dims,int input2_shift,int16 output_activation_min,int16 output_activation_max,int16 * output_data,const Dims<4> & output_dims)1346 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
1347                 int input1_shift, const int16* input2_data,
1348                 const Dims<4>& input2_dims, int input2_shift,
1349                 int16 output_activation_min, int16 output_activation_max,
1350                 int16* output_data, const Dims<4>& output_dims) {
1351   static_assert(Ac == FusedActivationFunctionType::kNone ||
1352                     Ac == FusedActivationFunctionType::kRelu ||
1353                     Ac == FusedActivationFunctionType::kRelu6 ||
1354                     Ac == FusedActivationFunctionType::kRelu1,
1355                 "");
1356   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1357   if (Ac == FusedActivationFunctionType::kNone) {
1358     TFLITE_DCHECK_EQ(output_activation_min, -32768);
1359     TFLITE_DCHECK_EQ(output_activation_max, 32767);
1360   }
1361 
1362   tflite::ArithmeticParams op_params;
1363   op_params.input1_shift = kReverseShift * input1_shift;
1364   op_params.input2_shift = kReverseShift * input2_shift;
1365   op_params.quantized_activation_min = output_activation_min;
1366   op_params.quantized_activation_max = output_activation_max;
1367   Add(op_params, DimsToShape(input1_dims), input1_data,
1368       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1369       output_data);
1370 }
1371 
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1372 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1373                 const float* input2_data, const Dims<4>& input2_dims,
1374                 float* output_data, const Dims<4>& output_dims) {
1375   float output_activation_min, output_activation_max;
1376   GetActivationMinMax(FusedActivationFunctionType::kNone,
1377                       &output_activation_min, &output_activation_max);
1378   tflite::ArithmeticParams op_params;
1379   op_params.float_activation_min = output_activation_min;
1380   op_params.float_activation_max = output_activation_max;
1381   Sub(op_params, DimsToShape(input1_dims), input1_data,
1382       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1383       output_data);
1384 }
1385 
1386 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1387 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1388          const Dims<4>& input2_dims, T* output_data,
1389          const Dims<4>& output_dims) {
1390   tflite::ArithmeticParams op_params;
1391   op_params.quantized_activation_min = std::numeric_limits<T>::min();
1392   op_params.quantized_activation_max = std::numeric_limits<T>::max();
1393   Sub(op_params, DimsToShape(input1_dims), input1_data,
1394       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1395       output_data);
1396 }
1397 
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1398 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
1399                         int stride_width, int stride_height, int pad_width,
1400                         int pad_height, int kwidth, int kheight,
1401                         float output_activation_min,
1402                         float output_activation_max, float* output_data,
1403                         const Dims<4>& output_dims) {
1404   tflite::PoolParams params;
1405   params.stride_height = stride_height;
1406   params.stride_width = stride_width;
1407   params.filter_height = kheight;
1408   params.filter_width = kwidth;
1409   params.padding_values.height = pad_height;
1410   params.padding_values.width = pad_width;
1411   params.float_activation_min = output_activation_min;
1412   params.float_activation_max = output_activation_max;
1413   AveragePool(params, DimsToShape(input_dims), input_data,
1414               DimsToShape(output_dims), output_data);
1415 }
1416 
1417 // Transitional version that will be moved shortly to legacy_reference_ops, as
1418 // part of RuntimeShape revisions.
BroadcastMul4DSlow(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1419 inline void BroadcastMul4DSlow(const uint8* input1_data,
1420                                const Dims<4>& input1_dims, int32 input1_offset,
1421                                const uint8* input2_data,
1422                                const Dims<4>& input2_dims, int32 input2_offset,
1423                                int32 output_offset, int32 output_multiplier,
1424                                int output_shift, int32 output_activation_min,
1425                                int32 output_activation_max, uint8* output_data,
1426                                const Dims<4>& output_dims) {
1427   tflite::ArithmeticParams op_params;
1428   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1429   op_params.input1_offset = input1_offset;
1430   op_params.input2_offset = input2_offset;
1431   op_params.output_offset = output_offset;
1432   op_params.output_multiplier = output_multiplier;
1433   op_params.output_shift = output_shift;
1434 
1435   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1436                      DimsToShape(input2_dims), input2_data,
1437                      DimsToShape(output_dims), output_data);
1438 }
1439 
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1440 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1441                          int32 input1_offset, const uint8* input2_data,
1442                          const Dims<4>& input2_dims, int32 input2_offset,
1443                          int32 output_offset, int32 output_multiplier,
1444                          int output_shift, int32 output_activation_min,
1445                          int32 output_activation_max, uint8* output_data,
1446                          const Dims<4>& output_dims) {
1447   BroadcastMul4DSlow(
1448       input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1449       input2_offset, output_offset, output_multiplier,
1450       //
1451       kReverseShift * output_shift,
1452       //
1453       output_activation_min, output_activation_max, output_data, output_dims);
1454 }
1455 
1456 // legacy, for compatibility with old checked-in code
1457 template <FusedActivationFunctionType Ac>
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1458 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1459                          int32 input1_offset, const uint8* input2_data,
1460                          const Dims<4>& input2_dims, int32 input2_offset,
1461                          int32 output_offset, int32 output_multiplier,
1462                          int output_shift, int32 output_activation_min,
1463                          int32 output_activation_max, uint8* output_data,
1464                          const Dims<4>& output_dims) {
1465   BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
1466                input2_dims, input2_offset, output_offset, output_multiplier,
1467                output_shift, output_activation_min, output_activation_max,
1468                output_data, output_dims);
1469 }
1470 
1471 // legacy, for compatibility with old checked-in code
1472 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1473 void AveragePool(const float* input_data, const Dims<4>& input_dims,
1474                  int stride_width, int stride_height, int pad_width,
1475                  int pad_height, int kwidth, int kheight, float* output_data,
1476                  const Dims<4>& output_dims) {
1477   float output_activation_min, output_activation_max;
1478   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1479 
1480   AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
1481               pad_height, kwidth, kheight, output_activation_min,
1482               output_activation_max, output_data, output_dims);
1483 }
1484 
1485 // legacy, for compatibility with old checked-in code
1486 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1487 void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
1488                  int pad_width, int pad_height, int filter_width,
1489                  int filter_height, float* output_data,
1490                  const Dims<4>& output_dims) {
1491   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1492                   filter_width, filter_height, output_data, output_dims);
1493 }
1494 
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1495 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1496                         int stride_width, int stride_height, int pad_width,
1497                         int pad_height, int filter_width, int filter_height,
1498                         int32 output_activation_min,
1499                         int32 output_activation_max, uint8* output_data,
1500                         const Dims<4>& output_dims) {
1501   tflite::PoolParams params;
1502   params.stride_height = stride_height;
1503   params.stride_width = stride_width;
1504   params.filter_height = filter_height;
1505   params.filter_width = filter_width;
1506   params.padding_values.height = pad_height;
1507   params.padding_values.width = pad_width;
1508   params.quantized_activation_min = output_activation_min;
1509   params.quantized_activation_max = output_activation_max;
1510   AveragePool(params, DimsToShape(input_dims), input_data,
1511               DimsToShape(output_dims), output_data);
1512 }
1513 
// legacy, for compatibility with old checked-in code
// Templated quantized AveragePool: the Ac parameter is validated at compile
// time, and for kNone the explicit clamp must be the identity for uint8.
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int filter_width, int filter_height,
                 int32 output_activation_min, int32 output_activation_max,
                 uint8* output_data, const Dims<4>& output_dims) {
  // Only the four known fused-activation kinds are supported.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // "No activation" must mean no clamping: full uint8 range expected.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  // Forward to the untemplated overload, which packs PoolParams.
  AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
              pad_height, filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}
1534 
1535 // legacy, for compatibility with old checked-in code
1536 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1537 void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1538                  int pad_width, int pad_height, int filter_width,
1539                  int filter_height, int32 output_activation_min,
1540                  int32 output_activation_max, uint8* output_data,
1541                  const Dims<4>& output_dims) {
1542   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1543                   filter_width, filter_height, output_activation_min,
1544                   output_activation_max, output_data, output_dims);
1545 }
1546 
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1547 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1548                     int stride_width, int stride_height, int pad_width,
1549                     int pad_height, int kwidth, int kheight,
1550                     float output_activation_min, float output_activation_max,
1551                     float* output_data, const Dims<4>& output_dims) {
1552   tflite::PoolParams params;
1553   params.stride_height = stride_height;
1554   params.stride_width = stride_width;
1555   params.filter_height = kheight;
1556   params.filter_width = kwidth;
1557   params.padding_values.height = pad_height;
1558   params.padding_values.width = pad_width;
1559   params.float_activation_min = output_activation_min;
1560   params.float_activation_max = output_activation_max;
1561   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1562           output_data);
1563 }
1564 
1565 // legacy, for compatibility with old checked-in code
1566 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1567 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1568              int stride_width, int stride_height, int pad_width, int pad_height,
1569              int kwidth, int kheight, float* output_data,
1570              const Dims<4>& output_dims) {
1571   float output_activation_min, output_activation_max;
1572   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1573   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1574           pad_height, kwidth, kheight, output_activation_min,
1575           output_activation_max, output_data, output_dims);
1576 }
1577 
1578 // legacy, for compatibility with old checked-in code
1579 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1580 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
1581              int pad_width, int pad_height, int filter_width, int filter_height,
1582              float* output_data, const Dims<4>& output_dims) {
1583   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1584               filter_width, filter_height, output_data, output_dims);
1585 }
1586 
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1587 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1588                     int stride_width, int stride_height, int pad_width,
1589                     int pad_height, int filter_width, int filter_height,
1590                     int32 output_activation_min, int32 output_activation_max,
1591                     uint8* output_data, const Dims<4>& output_dims) {
1592   PoolParams params;
1593   params.stride_height = stride_height;
1594   params.stride_width = stride_width;
1595   params.filter_height = filter_height;
1596   params.filter_width = filter_width;
1597   params.padding_values.height = pad_height;
1598   params.padding_values.width = pad_width;
1599   params.quantized_activation_min = output_activation_min;
1600   params.quantized_activation_max = output_activation_max;
1601   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1602           output_data);
1603 }
1604 
// legacy, for compatibility with old checked-in code
// Templated quantized MaxPool: the Ac parameter is validated at compile time,
// and for kNone the explicit clamp must be the identity for uint8.
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
             int stride_width, int stride_height, int pad_width, int pad_height,
             int filter_width, int filter_height, int32 output_activation_min,
             int32 output_activation_max, uint8* output_data,
             const Dims<4>& output_dims) {
  // Only the four known fused-activation kinds are supported.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // "No activation" must mean no clamping: full uint8 range expected.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  // Forward to the untemplated overload, which packs PoolParams.
  MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
          pad_height, filter_width, filter_height, output_activation_min,
          output_activation_max, output_data, output_dims);
}
1625 
1626 // legacy, for compatibility with old checked-in code
1627 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1628 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1629              int pad_width, int pad_height, int filter_width, int filter_height,
1630              int32 output_activation_min, int32 output_activation_max,
1631              uint8* output_data, const Dims<4>& output_dims) {
1632   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1633               filter_width, filter_height, output_activation_min,
1634               output_activation_max, output_data, output_dims);
1635 }
1636 
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1637 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1638                    int stride_width, int stride_height, int pad_width,
1639                    int pad_height, int filter_width, int filter_height,
1640                    float output_activation_min, float output_activation_max,
1641                    float* output_data, const Dims<4>& output_dims) {
1642   PoolParams params;
1643   params.stride_height = stride_height;
1644   params.stride_width = stride_width;
1645   params.filter_height = filter_height;
1646   params.filter_width = filter_width;
1647   params.padding_values.height = pad_height;
1648   params.padding_values.width = pad_width;
1649   params.float_activation_min = output_activation_min;
1650   params.float_activation_max = output_activation_max;
1651   L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1652          output_data);
1653 }
1654 
1655 // legacy, for compatibility with old checked-in code
1656 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1657 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1658             int stride_width, int stride_height, int pad_width, int pad_height,
1659             int filter_width, int filter_height, float* output_data,
1660             const Dims<4>& output_dims) {
1661   float output_activation_min, output_activation_max;
1662   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1663   L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1664          pad_height, filter_width, filter_height, output_activation_min,
1665          output_activation_max, output_data, output_dims);
1666 }
1667 
1668 // legacy, for compatibility with old checked-in code
1669 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1670 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
1671             int pad_width, int pad_height, int filter_width, int filter_height,
1672             float* output_data, const Dims<4>& output_dims) {
1673   L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1674              filter_width, filter_height, output_data, output_dims);
1675 }
1676 
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1677 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1678                     float beta, float* output_data,
1679                     const Dims<4>& output_dims) {
1680   Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1681           DimsToShape(output_dims));
1682 }
1683 
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1684 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1685                     int32 input_beta_multiplier, int32 input_beta_left_shift,
1686                     int diff_min, uint8* output_data,
1687                     const Dims<4>& output_dims) {
1688   Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1689           input_beta_left_shift, diff_min, output_data,
1690           DimsToShape(output_dims));
1691 }
1692 
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1693 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1694                        float* output_data, const Dims<4>& output_dims) {
1695   LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1696              DimsToShape(output_dims));
1697 }
1698 
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1699 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1700                        int32 input_multiplier, int32 input_left_shift,
1701                        int32 reverse_scaling_divisor,
1702                        int32 reverse_scaling_right_shift, int diff_min,
1703                        uint8* output_data, const Dims<4>& output_dims) {
1704   LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1705              input_left_shift, reverse_scaling_divisor,
1706              reverse_scaling_right_shift, diff_min, output_data,
1707              DimsToShape(output_dims));
1708 }
1709 
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1710 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1711                      float* output_data, const Dims<4>& output_dims) {
1712   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1713            output_data);
1714 }
1715 
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1716 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1717                      int32 input_zero_point, int32 input_range_radius,
1718                      int32 input_multiplier, int input_left_shift,
1719                      uint8* output_data, const Dims<4>& output_dims) {
1720   Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1721            input_range_radius, input_multiplier, input_left_shift, output_data,
1722            DimsToShape(output_dims));
1723 }
1724 
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1725 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1726                      int16* output_data, const Dims<4>& output_dims) {
1727   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1728            output_data);
1729 }
1730 
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1731 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1732                  float* output_data, const Dims<4>& output_dims) {
1733   Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1734        output_data);
1735 }
1736 
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1737 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1738                  int32 input_zero_point, int32 input_range_radius,
1739                  int32 input_multiplier, int input_left_shift,
1740                  uint8* output_data, const Dims<4>& output_dims) {
1741   Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1742        input_range_radius, input_multiplier, input_left_shift, output_data,
1743        DimsToShape(output_dims));
1744 }
1745 
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1746 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1747                  int input_left_shift, int16* output_data,
1748                  const Dims<4>& output_dims) {
1749   Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1750        DimsToShape(output_dims));
1751 }
1752 
1753 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1754 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1755                          int block_size, T* output_data,
1756                          const Dims<4>& output_dims) {
1757   tflite::DepthToSpaceParams op_params;
1758   op_params.block_size = block_size;
1759 
1760   DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1761                DimsToShape(output_dims), output_data);
1762 }
1763 
1764 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1765 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1766                          int block_size, T* output_data,
1767                          const Dims<4>& output_dims) {
1768   tflite::SpaceToDepthParams op_params;
1769   op_params.block_size = block_size;
1770 
1771   SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1772                DimsToShape(output_dims), output_data);
1773 }
1774 
1775 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1776 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1777                 const T* input2_data, const Dims<4>& input2_dims,
1778                 T output_activation_min, T output_activation_max,
1779                 T* output_data, const Dims<4>& output_dims) {
1780   tflite::ArithmeticParams op_params;
1781   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1782 
1783   Mul(op_params, DimsToShape(input1_dims), input1_data,
1784       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1785       output_data);
1786 }
1787 
1788 // legacy, for compatibility with old checked-in code
1789 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1790 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1791          const float* input2_data, const Dims<4>& input2_dims,
1792          float* output_data, const Dims<4>& output_dims) {
1793   float output_activation_min, output_activation_max;
1794   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1795 
1796   tflite::ArithmeticParams op_params;
1797   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1798 
1799   Mul(op_params, DimsToShape(input1_dims), input1_data,
1800       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1801       output_data);
1802 }
1803 
1804 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1805 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1806                   const T* input2_data, const Dims<4>& input2_dims,
1807                   T output_activation_min, T output_activation_max,
1808                   T* output_data, const Dims<4>& output_dims) {
1809   tflite::ArithmeticParams op_params;
1810   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1811 
1812   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1813                      DimsToShape(input2_dims), input2_data,
1814                      DimsToShape(output_dims), output_data);
1815 }
1816 
1817 // legacy, for compatibility with old checked-in code
1818 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1819 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1820                   const T* input2_data, const Dims<4>& input2_dims,
1821                   T* output_data, const Dims<4>& output_dims) {
1822   T output_activation_min, output_activation_max;
1823   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1824 
1825   tflite::ArithmeticParams op_params;
1826   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1827 
1828   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1829                      DimsToShape(input2_dims), input2_data,
1830                      DimsToShape(output_dims), output_data);
1831 }
1832 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1833 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1834                 const int16* input2_data, const Dims<4>& input2_dims,
1835                 int16* output_data, const Dims<4>& output_dims) {
1836   tflite::ArithmeticParams op_params;
1837   // No params in this version.
1838 
1839   Mul(op_params, DimsToShape(input1_dims), input1_data,
1840       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1841       output_data);
1842 }
1843 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1844 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1845                 const int16* input2_data, const Dims<4>& input2_dims,
1846                 int32 output_offset, int32 output_activation_min,
1847                 int32 output_activation_max, uint8* output_data,
1848                 const Dims<4>& output_dims) {
1849   tflite::ArithmeticParams op_params;
1850   op_params.quantized_activation_min = output_activation_min;
1851   op_params.quantized_activation_max = output_activation_max;
1852   op_params.output_offset = output_offset;
1853 
1854   Mul(op_params, DimsToShape(input1_dims), input1_data,
1855       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1856       output_data);
1857 }
1858 
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1859 inline void LocalResponseNormalization(const float* input_data,
1860                                        const Dims<4>& input_dims, int range,
1861                                        float bias, float alpha, float beta,
1862                                        float* output_data,
1863                                        const Dims<4>& output_dims) {
1864   tflite::LocalResponseNormalizationParams op_params;
1865   op_params.range = range;
1866   op_params.bias = bias;
1867   op_params.alpha = alpha;
1868   op_params.beta = beta;
1869 
1870   LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1871                              DimsToShape(output_dims), output_data);
1872 }
1873 
1874 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1875 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1876           const Dims<4>& output_dims) {
1877   Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1878        output_data);
1879 }
1880 
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1881 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1882                   float* output_data, const Dims<4>& output_dims) {
1883   Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1884         output_data);
1885 }
1886 
1887 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1888 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1889                            const int32* output_size_data,
1890                            const Dims<4>& output_size_dims, T* output_data,
1891                            const Dims<4>& output_dims, bool align_corners) {
1892   tflite::ResizeBilinearParams op_params;
1893   op_params.align_corners = align_corners;
1894   ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1895                  DimsToShape(output_size_dims), output_size_data,
1896                  DimsToShape(output_dims), output_data);
1897 }
1898 
1899 // legacy, for compatibility with old checked-in code
ResizeBilinear(const float * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,float * output_data,const Dims<4> & output_dims)1900 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
1901                            const int32* output_size_data,
1902                            const Dims<4>& output_size_dims, float* output_data,
1903                            const Dims<4>& output_dims) {
1904   ResizeBilinear<float>(input_data, input_dims, output_size_data,
1905                         output_size_dims, output_data, output_dims,
1906                         /*align_corners=*/false);
1907 }
1908 
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)1909 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
1910                            const int32* output_size_data,
1911                            const Dims<4>& output_size_dims, uint8* output_data,
1912                            const Dims<4>& output_dims) {
1913   ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
1914                         output_size_dims, output_data, output_dims,
1915                         /*align_corners=*/false);
1916 }
1917 
1918 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)1919 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
1920                            const int32* block_shape_data,
1921                            const Dims<4>& block_shape_dims,
1922                            const int32* paddings_data,
1923                            const Dims<4>& paddings_dims, T* output_data,
1924                            const Dims<4>& output_dims,
1925                            const int32_t pad_value) {
1926   tflite::SpaceToBatchParams op_params;
1927   op_params.output_offset = pad_value;
1928 
1929   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
1930                  DimsToShape(block_shape_dims), block_shape_data,
1931                  DimsToShape(paddings_dims), paddings_data,
1932                  DimsToShape(output_dims), output_data);
1933 }
1934 
1935 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)1936 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
1937                            const int32* block_shape_data,
1938                            const Dims<4>& block_shape_dims,
1939                            const int32* paddings_data,
1940                            const Dims<4>& paddings_dims, T* output_data,
1941                            const Dims<4>& output_dims) {
1942   tflite::SpaceToBatchParams op_params;
1943   op_params.output_offset = 0;
1944 
1945   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
1946                  DimsToShape(block_shape_dims), block_shape_data,
1947                  DimsToShape(paddings_dims), paddings_data,
1948                  DimsToShape(output_dims), output_data);
1949 }
1950 
1951 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)1952 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
1953                            const int32* block_shape_data,
1954                            const Dims<4>& block_shape_dims,
1955                            const int32* crops_data, const Dims<4>& crops_dims,
1956                            T* output_data, const Dims<4>& output_dims) {
1957   BatchToSpaceND(DimsToShape(input_dims), input_data,
1958                  DimsToShape(block_shape_dims), block_shape_data,
1959                  DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
1960                  output_data);
1961 }
1962 
1963 // Legacy signature, function covered both Pad and PadV2.
1964 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)1965 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
1966                   const std::vector<int>& left_paddings,
1967                   const std::vector<int>& right_paddings, T* output_data,
1968                   const Dims<4>& output_dims, const T pad_value) {
1969   TFLITE_DCHECK_EQ(left_paddings.size(), 4);
1970   TFLITE_DCHECK_EQ(right_paddings.size(), 4);
1971   tflite::PadParams op_params;
1972   op_params.left_padding_count = 4;
1973   op_params.right_padding_count = 4;
1974   for (int i = 0; i < 4; ++i) {
1975     op_params.left_padding[i] = left_paddings[3 - i];
1976     op_params.right_padding[i] = right_paddings[3 - i];
1977   }
1978   // SetFloatOrInt(pad_value, &op_params.pad_value);
1979   const T pad_value_copy = pad_value;
1980 
1981   Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
1982       DimsToShape(output_dims), output_data);
1983 }
1984 
1985 // Old Pad that calls legacy PadV2.
1986 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)1987 inline void Pad(const T* input_data, const Dims<4>& input_dims,
1988                 const std::vector<int>& left_paddings,
1989                 const std::vector<int>& right_paddings, T* output_data,
1990                 const Dims<4>& output_dims, const int32_t pad_value) {
1991   const T converted_pad_value = static_cast<T>(pad_value);
1992   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
1993            output_dims, converted_pad_value);
1994 }
1995 
1996 // Old Pad that only padded with 0.
1997 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)1998 inline void Pad(const T* input_data, const Dims<4>& input_dims,
1999                 const std::vector<int>& left_paddings,
2000                 const std::vector<int>& right_paddings, T* output_data,
2001                 const Dims<4>& output_dims) {
2002   const T pad_value = static_cast<T>(0);
2003   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2004            output_dims, pad_value);
2005 }
2006 
2007 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2008 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2009                        const T* input2_data, T* output_data,
2010                        const Dims<4>& output_dims) {
2011   Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2012           DimsToShape(output_dims), output_data);
2013 }
2014 
2015 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2016 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2017                        const T* input2_data, T* output_data,
2018                        const Dims<4>& output_dims) {
2019   Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2020           DimsToShape(output_dims), output_data);
2021 }
2022 
2023 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2024 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2025                               const T* input2_data, const Dims<4>& input2_dims,
2026                               T* output_data, const Dims<4>& output_dims,
2027                               Op op) {
2028   MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
2029                                 DimsToShape(input2_dims), input2_data,
2030                                 DimsToShape(output_dims), output_data, op);
2031 }
2032 
// Legacy ArgMax over the last (RuntimeShape) axis of a 4-D Dims<4> tensor.
// T1: input element type. T2: output index type. T3: axis element type.
// The axis argument must point to the single value 3 (enforced below).
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  // Dims<4> lists dimensions innermost-first, so dropping sizes[3] and
  // reversing yields the 3-D output shape in RuntimeShape order.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[3] is always 1.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  // std::greater selects the maximum, making this ArgMax rather than ArgMin.
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, std::greater<T1>());
}
2049 
2050 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2051 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2052                T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2053   ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2054             output_data, cmp);
2055 }
2056 
2057 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2058 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2059                 const T* input2_data, const Dims<4>& input2_dims,
2060                 T* output_data, const Dims<4>& output_dims) {
2061   Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2062       input2_data, DimsToShape(output_dims), output_data);
2063 }
2064 
2065 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2066 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2067                          const T* input2_data, const Dims<4>& input2_dims,
2068                          T* output_data, const Dims<4>& output_dims) {
2069   BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2070                      DimsToShape(input2_dims), input2_data,
2071                      DimsToShape(output_dims), output_data);
2072 }
2073 
Logical(const bool * input1_data,const Dims<4> & input1_dims,const bool * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims,const std::function<bool (bool,bool)> & func)2074 inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
2075                     const bool* input2_data, const Dims<4>& input2_dims,
2076                     bool* output_data, const Dims<4>& output_dims,
2077                     const std::function<bool(bool, bool)>& func) {
2078   Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2079           input2_data, DimsToShape(output_dims), output_data, func);
2080 }
2081 
BroadcastLogical(const bool * input1_data,const Dims<4> & input1_dims,const bool * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims,const std::function<bool (bool,bool)> & func)2082 inline void BroadcastLogical(const bool* input1_data,
2083                              const Dims<4>& input1_dims,
2084                              const bool* input2_data,
2085                              const Dims<4>& input2_dims, bool* output_data,
2086                              const Dims<4>& output_dims,
2087                              const std::function<bool(bool, bool)>& func) {
2088   BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
2089                          DimsToShape(input2_dims), input2_data,
2090                          DimsToShape(output_dims), output_data, func);
2091 }
2092 
2093 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2094 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2095 inline void BroadcastBinaryFunction(const T1* input1_data,
2096                                     const Dims<4>& input1_dims,
2097                                     const T2* input2_data,
2098                                     const Dims<4>& input2_dims, R* output_data,
2099                                     const Dims<4>& output_dims,
2100                                     R (*func)(T1, T2)) {
2101   BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2102                           DimsToShape(input2_dims), input2_data,
2103                           DimsToShape(output_dims), output_data, func);
2104 }
2105 
2106 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2107 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2108 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2109                            const T2* input2_data, const Dims<4>& input2_dims,
2110                            R* output_data, const Dims<4>& output_dims,
2111                            R (*func)(T1, T2)) {
2112   BinaryFunction(DimsToShape(input1_dims), input1_data,
2113                  DimsToShape(input2_dims), input2_data,
2114                  DimsToShape(output_dims), output_data, func);
2115 }
2116 
2117 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2118 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2119                   const std::vector<int>& begin, const std::vector<int>& size,
2120                   T* output_data, const Dims<4>& output_dims) {
2121   tflite::SliceParams op_params;
2122   op_params.begin_count = 4;
2123   op_params.size_count = 4;
2124   for (int i = 0; i < 4; ++i) {
2125     op_params.begin[i] = begin[3 - i];
2126     op_params.size[i] = size[3 - i];
2127   }
2128 
2129   Slice(op_params, DimsToShape(input_dims), input_data,
2130         DimsToShape(output_dims), output_data);
2131 }
2132 
2133 }  // namespace reference_ops
2134 }  // namespace tflite
2135 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2136