1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
17
18 #include <stdint.h>
19 #include <sys/types.h>
20
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/legacy_types.h"
23 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
24 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27
28 namespace tflite {
29
30 namespace reference_ops {
31
32 static constexpr int kDepthwiseReverseShift = -1;
33
ShapeFromDims(const tflite::Dims<4> & dims,RuntimeShape * shape)34 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
35 shape->BuildFrom(
36 {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
37 }
38
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)39 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
40 const float* filter_data, const Dims<4>& filter_dims,
41 const float* bias_data, const Dims<4>& bias_dims,
42 int stride_width, int stride_height,
43 int dilation_width_factor, int dilation_height_factor,
44 int pad_width, int pad_height, int depth_multiplier,
45 float output_activation_min,
46 float output_activation_max, float* output_data,
47 const Dims<4>& output_dims) {
48 tflite::DepthwiseParams op_params;
49 // Padding type is ignored, but still set.
50 op_params.padding_type = PaddingType::kSame;
51 op_params.padding_values.width = pad_width;
52 op_params.padding_values.height = pad_height;
53 op_params.stride_width = stride_width;
54 op_params.stride_height = stride_height;
55 op_params.dilation_width_factor = dilation_width_factor;
56 op_params.dilation_height_factor = dilation_height_factor;
57 op_params.depth_multiplier = depth_multiplier;
58 op_params.float_activation_min = output_activation_min;
59 op_params.float_activation_max = output_activation_max;
60
61 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
62 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
63 bias_data, DimsToShape(output_dims), output_data);
64 }
65
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)66 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
67 const float* filter_data, const Dims<4>& filter_dims,
68 const float* bias_data, const Dims<4>& bias_dims,
69 int stride_width, int stride_height, int pad_width,
70 int pad_height, int depth_multiplier,
71 float output_activation_min,
72 float output_activation_max, float* output_data,
73 const Dims<4>& output_dims) {
74 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
75 bias_dims, stride_width, stride_height, 1, 1, pad_width,
76 pad_height, depth_multiplier, output_activation_min,
77 output_activation_max, output_data, output_dims);
78 }
79
80 // Legacy, for compatibility with old checked-in code.
81 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)82 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
83 const float* filter_data, const Dims<4>& filter_dims,
84 const float* bias_data, const Dims<4>& bias_dims,
85 int stride_width, int stride_height, int pad_width,
86 int pad_height, int depth_multiplier, float* output_data,
87 const Dims<4>& output_dims) {
88 float output_activation_min, output_activation_max;
89 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
90 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
91 bias_dims, stride_width, stride_height, pad_width, pad_height,
92 depth_multiplier, output_activation_min, output_activation_max,
93 output_data, output_dims);
94 }
95
// Legacy, for compatibility with old checked-in code.
// Single-stride variant: applies `stride` to both spatial dimensions and
// forwards every other argument verbatim to the
// (stride_width, stride_height) templated overload.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}
107
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)108 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
109 int32 input_offset, const uint8* filter_data,
110 const Dims<4>& filter_dims, int32 filter_offset,
111 const int32* bias_data, const Dims<4>& bias_dims,
112 int stride_width, int stride_height,
113 int dilation_width_factor, int dilation_height_factor,
114 int pad_width, int pad_height, int depth_multiplier,
115 int32 output_offset, int32 output_multiplier,
116 int output_shift, int32 output_activation_min,
117 int32 output_activation_max, uint8* output_data,
118 const Dims<4>& output_dims) {
119 tflite::DepthwiseParams op_params;
120 // Padding type is ignored, but still set.
121 op_params.padding_type = PaddingType::kSame;
122 op_params.padding_values.width = pad_width;
123 op_params.padding_values.height = pad_height;
124 op_params.stride_width = stride_width;
125 op_params.stride_height = stride_height;
126 op_params.dilation_width_factor = dilation_width_factor;
127 op_params.dilation_height_factor = dilation_height_factor;
128 op_params.depth_multiplier = depth_multiplier;
129 op_params.quantized_activation_min = output_activation_min;
130 op_params.quantized_activation_max = output_activation_max;
131 op_params.input_offset = input_offset;
132 op_params.weights_offset = filter_offset;
133 op_params.output_offset = output_offset;
134 op_params.output_multiplier = output_multiplier;
135 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
136 op_params.output_shift = kDepthwiseReverseShift * output_shift;
137
138 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
139 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
140 bias_data, DimsToShape(output_dims), output_data);
141 }
142
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)143 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
144 int32 input_offset, const uint8* filter_data,
145 const Dims<4>& filter_dims, int32 filter_offset,
146 const int32* bias_data, const Dims<4>& bias_dims,
147 int stride_width, int stride_height, int pad_width,
148 int pad_height, int depth_multiplier,
149 int32 output_offset, int32 output_multiplier,
150 int output_shift, int32 output_activation_min,
151 int32 output_activation_max, uint8* output_data,
152 const Dims<4>& output_dims) {
153 DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
154 filter_offset, bias_data, bias_dims, stride_width,
155 stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
156 output_offset, output_multiplier, output_shift,
157 output_activation_min, output_activation_max, output_data,
158 output_dims);
159 }
160
// Legacy, for compatibility with old checked-in code.
// Quantized depthwise convolution with a compile-time fused activation. When
// no activation is fused, debug builds verify that the supplied clamp range
// spans the full uint8 domain [0, 255] before forwarding verbatim.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}
183
// Legacy, for compatibility with old checked-in code.
// Single-stride variant: applies `stride` to both spatial dimensions and
// forwards every other argument verbatim to the
// (stride_width, stride_height) templated overload.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}
202
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)203 inline void Conv(const float* input_data, const Dims<4>& input_dims,
204 const float* filter_data, const Dims<4>& filter_dims,
205 const float* bias_data, const Dims<4>& bias_dims,
206 int stride_width, int stride_height, int dilation_width_factor,
207 int dilation_height_factor, int pad_width, int pad_height,
208 float output_activation_min, float output_activation_max,
209 float* output_data, const Dims<4>& output_dims,
210 float* im2col_data, const Dims<4>& im2col_dims) {
211 tflite::ConvParams op_params;
212 // Padding type is ignored, but still set.
213 op_params.padding_type = PaddingType::kSame;
214 op_params.padding_values.width = pad_width;
215 op_params.padding_values.height = pad_height;
216 op_params.stride_width = stride_width;
217 op_params.stride_height = stride_height;
218 op_params.dilation_width_factor = dilation_width_factor;
219 op_params.dilation_height_factor = dilation_height_factor;
220 op_params.float_activation_min = output_activation_min;
221 op_params.float_activation_max = output_activation_max;
222
223 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
224 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
225 output_data, DimsToShape(im2col_dims), im2col_data);
226 }
227
228 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)229 void Conv(const float* input_data, const Dims<4>& input_dims,
230 const float* filter_data, const Dims<4>& filter_dims,
231 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
232 int stride_height, int dilation_width_factor,
233 int dilation_height_factor, int pad_width, int pad_height,
234 float* output_data, const Dims<4>& output_dims, float* im2col_data,
235 const Dims<4>& im2col_dims) {
236 float output_activation_min, output_activation_max;
237 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
238 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
239 stride_width, stride_height, dilation_width_factor,
240 dilation_height_factor, pad_width, pad_height, output_activation_min,
241 output_activation_max, output_data, output_dims, im2col_data,
242 im2col_dims);
243 }
244
245 // legacy, for compatibility with old checked-in code
246 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)247 void Conv(const float* input_data, const Dims<4>& input_dims,
248 const float* filter_data, const Dims<4>& filter_dims,
249 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
250 int stride_height, int pad_width, int pad_height, float* output_data,
251 const Dims<4>& output_dims, float* im2col_data,
252 const Dims<4>& im2col_dims) {
253 float output_activation_min, output_activation_max;
254 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
255 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
256 stride_width, stride_height, 1, 1, pad_width, pad_height,
257 output_activation_min, output_activation_max, output_data, output_dims,
258 im2col_data, im2col_dims);
259 }
260
// legacy, for compatibility with old checked-in code
// Single-stride variant: applies `stride` to both spatial dimensions and unit
// dilation factors, forwarding to the templated dilated overload.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
          const float* filter_data, const Dims<4>& filter_dims,
          const float* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, float* output_data,
          const Dims<4>& output_dims, float* im2col_data,
          const Dims<4>& im2col_dims) {
  Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
           bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
           output_dims, im2col_data, im2col_dims);
}
273
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemm_context)274 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
275 int32 input_offset, const uint8* filter_data,
276 const Dims<4>& filter_dims, int32 filter_offset,
277 const int32* bias_data, const Dims<4>& bias_dims,
278 int stride_width, int stride_height, int dilation_width_factor,
279 int dilation_height_factor, int pad_width, int pad_height,
280 int32 output_offset, int32 output_multiplier, int output_shift,
281 int32 output_activation_min, int32 output_activation_max,
282 uint8* output_data, const Dims<4>& output_dims,
283 uint8* im2col_data, const Dims<4>& im2col_dims,
284 gemmlowp::GemmContext* gemm_context) {
285 tflite::ConvParams op_params;
286 // Padding type is ignored, but still set.
287 op_params.padding_type = PaddingType::kSame;
288 op_params.padding_values.width = pad_width;
289 op_params.padding_values.height = pad_height;
290 op_params.stride_width = stride_width;
291 op_params.stride_height = stride_height;
292 op_params.dilation_width_factor = dilation_width_factor;
293 op_params.dilation_height_factor = dilation_height_factor;
294 op_params.input_offset = input_offset;
295 op_params.weights_offset = filter_offset;
296 op_params.output_offset = output_offset;
297 op_params.output_multiplier = output_multiplier;
298 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
299 op_params.output_shift = kReverseShift * output_shift;
300 op_params.quantized_activation_min = output_activation_min;
301 op_params.quantized_activation_max = output_activation_max;
302
303 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
304 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
305 output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
306 }
307
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemm_context)308 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
309 int32 input_offset, const uint8* filter_data,
310 const Dims<4>& filter_dims, int32 filter_offset,
311 const int32* bias_data, const Dims<4>& bias_dims,
312 int stride_width, int stride_height, int pad_width,
313 int pad_height, int32 output_offset, int32 output_multiplier,
314 int output_shift, int32 output_activation_min,
315 int32 output_activation_max, uint8* output_data,
316 const Dims<4>& output_dims, uint8* im2col_data,
317 const Dims<4>& im2col_dims,
318 gemmlowp::GemmContext* gemm_context) {
319 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
320 filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
321 pad_width, pad_height, output_offset, output_multiplier, output_shift,
322 output_activation_min, output_activation_max, output_data, output_dims,
323 im2col_data, im2col_dims, gemm_context);
324 }
325
// legacy, for compatibility with old checked-in code
// Quantized convolution with a compile-time fused activation. Only the four
// known activation kinds compile; when no activation is fused, debug builds
// verify that the clamp range spans the full uint8 domain [0, 255] before
// forwarding verbatim.
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_offset, const uint8* filter_data,
                 const Dims<4>& filter_dims, int32 filter_offset,
                 const int32* bias_data, const Dims<4>& bias_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int32 output_offset, int32 output_multiplier,
                 int output_shift, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims, uint8* im2col_data,
                 const Dims<4>& im2col_dims,
                 gemmlowp::GemmContext* gemm_context) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
       filter_offset, bias_data, bias_dims, stride_width, stride_height,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims, gemm_context);
}
354
// legacy, for compatibility with old checked-in code
// Single-stride variant: applies `stride` to both spatial dimensions and
// forwards every other argument verbatim to the templated
// (stride_width, stride_height) overload.
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
          int32 input_offset, const uint8* filter_data,
          const Dims<4>& filter_dims, int32 filter_offset,
          const int32* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, int32 output_offset,
          int32 output_multiplier, int output_shift,
          int32 output_activation_min, int32 output_activation_max,
          uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
          const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
  Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
           filter_offset, bias_data, bias_dims, stride, stride, pad_width,
           pad_height, output_offset, output_multiplier, output_shift,
           output_activation_min, output_activation_max, output_data,
           output_dims, im2col_data, im2col_dims, gemm_context);
}
372
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)373 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
374 const float* filter_data, const Dims<4>& filter_dims,
375 int stride_width, int stride_height, int pad_width,
376 int pad_height, float* output_data,
377 const Dims<4>& output_dims, float* im2col_data,
378 const Dims<4>& im2col_dims) {
379 tflite::ConvParams op_params;
380 // Padding type is ignored, but still set.
381 op_params.padding_type = PaddingType::kSame;
382 op_params.padding_values.width = pad_width;
383 op_params.padding_values.height = pad_height;
384 op_params.stride_width = stride_width;
385 op_params.stride_height = stride_height;
386
387 TransposeConv(op_params, DimsToShape(input_dims), input_data,
388 DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
389 output_data, DimsToShape(im2col_dims), im2col_data);
390 }
391
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)392 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
393 const float* weights_data,
394 const Dims<4>& weights_dims, const float* bias_data,
395 const Dims<4>& bias_dims,
396 float output_activation_min,
397 float output_activation_max, float* output_data,
398 const Dims<4>& output_dims) {
399 tflite::FullyConnectedParams op_params;
400 op_params.float_activation_min = output_activation_min;
401 op_params.float_activation_max = output_activation_max;
402
403 FullyConnected(op_params, DimsToShape(input_dims), input_data,
404 DimsToShape(weights_dims), weights_data,
405 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
406 output_data);
407 }
408
409 // legacy, for compatibility with old checked-in code
410 template <FusedActivationFunctionType Ac>
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float * output_data,const Dims<4> & output_dims)411 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
412 const float* weights_data, const Dims<4>& weights_dims,
413 const float* bias_data, const Dims<4>& bias_dims,
414 float* output_data, const Dims<4>& output_dims) {
415 float output_activation_min, output_activation_max;
416 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
417 FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
418 bias_dims, output_activation_min, output_activation_max,
419 output_data, output_dims);
420 }
421
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemm_context)422 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
423 int32 input_offset, const uint8* filter_data,
424 const Dims<4>& filter_dims, int32 filter_offset,
425 const int32* bias_data, const Dims<4>& bias_dims,
426 int32 output_offset, int32 output_multiplier,
427 int output_shift, int32 output_activation_min,
428 int32 output_activation_max, uint8* output_data,
429 const Dims<4>& output_dims,
430 gemmlowp::GemmContext* gemm_context) {
431 tflite::FullyConnectedParams op_params;
432 op_params.input_offset = input_offset;
433 op_params.weights_offset = filter_offset;
434 op_params.output_offset = output_offset;
435 op_params.output_multiplier = output_multiplier;
436 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
437 op_params.output_shift = kReverseShift * output_shift;
438 op_params.quantized_activation_min = output_activation_min;
439 op_params.quantized_activation_max = output_activation_max;
440
441 FullyConnected(op_params, DimsToShape(input_dims), input_data,
442 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
443 bias_data, DimsToShape(output_dims), output_data,
444 gemm_context);
445 }
446
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemm_context)447 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
448 int32 input_offset, const uint8* filter_data,
449 const Dims<4>& filter_dims, int32 filter_offset,
450 const int32* bias_data, const Dims<4>& bias_dims,
451 int32 output_offset, int32 output_multiplier,
452 int output_shift, int32 output_activation_min,
453 int32 output_activation_max, int16* output_data,
454 const Dims<4>& output_dims,
455 gemmlowp::GemmContext* gemm_context) {
456 tflite::FullyConnectedParams op_params;
457 op_params.input_offset = input_offset;
458 op_params.weights_offset = filter_offset;
459 op_params.output_offset = output_offset;
460 op_params.output_multiplier = output_multiplier;
461 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
462 op_params.output_shift = kReverseShift * output_shift;
463 op_params.quantized_activation_min = output_activation_min;
464 op_params.quantized_activation_max = output_activation_max;
465
466 FullyConnected(op_params, DimsToShape(input_dims), input_data,
467 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
468 bias_data, DimsToShape(output_dims), output_data,
469 gemm_context);
470 }
471
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemm_context)472 inline void ShuffledFullyConnected(
473 const uint8* input_data, const Dims<4>& input_dims,
474 const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
475 const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
476 int output_shift, int32 output_activation_min, int32 output_activation_max,
477 int16* output_data, const Dims<4>& output_dims,
478 uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
479 tflite::FullyConnectedParams op_params;
480 op_params.output_multiplier = output_multiplier;
481 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
482 op_params.output_shift = kReverseShift * output_shift;
483 op_params.quantized_activation_min = output_activation_min;
484 op_params.quantized_activation_max = output_activation_max;
485
486 ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
487 DimsToShape(weights_dims), shuffled_weights_data,
488 DimsToShape(bias_dims), bias_data,
489 DimsToShape(output_dims), output_data,
490 shuffled_input_workspace_data, gemm_context);
491 }
492
// legacy, for compatibility with old checked-in code
// Templated legacy entry point: validates the (compile-time) fused activation
// choice, sanity-checks the activation bounds, then forwards to the
// non-templated legacy FullyConnected overload above.
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                    int32 input_offset, const uint8* filter_data,
                    const Dims<4>& filter_dims, int32 filter_offset,
                    const int32* bias_data, const Dims<4>& bias_dims,
                    int32 output_offset, int32 output_multiplier,
                    int output_shift, int32 output_activation_min,
                    int32 output_activation_max, uint8* output_data,
                    const Dims<4>& output_dims,
                    gemmlowp::GemmContext* gemm_context) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  // With no fused activation, callers are expected to pass the full uint8
  // output range.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
                 filter_offset, bias_data, bias_dims, output_offset,
                 output_multiplier, output_shift, output_activation_min,
                 output_activation_max, output_data, output_dims, gemm_context);
}
518
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)519 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
520 const float* prev_activ_data,
521 const Dims<4>& prev_activ_dims, const float* weights_data,
522 const Dims<4>& weights_dims, const float* bias_data,
523 const Dims<4>& bias_dims, const float* prev_state_data,
524 const Dims<4>& prev_state_dims, float* output_state_data,
525 const Dims<4>& output_state_dims, float* output_activ_data,
526 const Dims<4>& output_activ_dims, float* concat_temp_data,
527 const Dims<4>& concat_temp_dims, float* activ_temp_data,
528 const Dims<4>& activ_temp_dims) {
529 tflite::LstmCellParams op_params;
530 // Float LSTM cell does not need parameters to be set: leave untouched.
531
532 LstmCell(op_params, DimsToShape(input_dims), input_data,
533 DimsToShape(prev_activ_dims), prev_activ_data,
534 DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
535 bias_data, DimsToShape(prev_state_dims), prev_state_data,
536 DimsToShape(output_state_dims), output_state_data,
537 DimsToShape(output_activ_dims), output_activ_data,
538 DimsToShape(concat_temp_dims), concat_temp_data,
539 DimsToShape(activ_temp_dims), activ_temp_data);
540 }
541
542 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemm_context)543 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
544 const uint8* prev_activ_data_uint8,
545 const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
546 const Dims<4>& weights_dims, const int32* bias_data_int32,
547 const Dims<4>& bias_dims, const int16* prev_state_data_int16,
548 const Dims<4>& prev_state_dims, int16* output_state_data_int16,
549 const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
550 const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
551 const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
552 const Dims<4>& activ_temp_dims, int32 weights_zero_point,
553 int32 accum_multiplier, int accum_shift,
554 gemmlowp::GemmContext* gemm_context) {
555 tflite::LstmCellParams op_params;
556 op_params.weights_zero_point = weights_zero_point;
557 op_params.accum_multiplier = accum_multiplier;
558 op_params.accum_shift = accum_shift;
559
560 LstmCell<StateIntegerBits>(
561 op_params, DimsToShape(input_dims), input_data_uint8,
562 DimsToShape(prev_activ_dims), prev_activ_data_uint8,
563 DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
564 bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
565 DimsToShape(output_state_dims), output_state_data_int16,
566 DimsToShape(output_activ_dims), output_activ_data_uint8,
567 DimsToShape(concat_temp_dims), concat_temp_data_uint8,
568 DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
569 }
570
571 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)572 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
573 const T* input2_data, const Dims<4>& input2_dims,
574 T output_activation_min, T output_activation_max,
575 T* output_data, const Dims<4>& output_dims) {
576 tflite::ArithmeticParams op_params;
577 SetActivationParams(output_activation_min, output_activation_max, &op_params);
578
579 BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
580 DimsToShape(input2_dims), input2_data,
581 DimsToShape(output_dims), output_data);
582 }
583
584 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)585 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
586 const T* input2_data, const Dims<4>& input2_dims,
587 T output_activation_min, T output_activation_max,
588 T* output_data, const Dims<4>& output_dims) {
589 tflite::ArithmeticParams op_params;
590 SetActivationParams(output_activation_min, output_activation_max, &op_params);
591
592 Div(op_params, DimsToShape(input1_dims), input1_data,
593 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
594 output_data);
595 }
596
597 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)598 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
599 const Dims<4>* const* input_dims, int inputs_count,
600 Scalar* output_data, const Dims<4>& output_dims) {
601 // For now we don't have a model with a Concatenation with fused activation.
602 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
603
604 std::vector<RuntimeShape> input_shapes(inputs_count);
605 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
606 for (int i = 0; i < inputs_count; ++i) {
607 ShapeFromDims(*input_dims[i], &input_shapes[i]);
608 input_shapes_indirect[i] = &input_shapes[i];
609 }
610 tflite::ConcatenationParams op_params;
611 op_params.axis = 3 - concat_dim;
612 op_params.inputs_count = inputs_count;
613
614 Concatenation(op_params, input_shapes_indirect.data(), input_data,
615 DimsToShape(output_dims), output_data);
616 }
617
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)618 inline void Concatenation(int concat_dim, const uint8* const* input_data,
619 const Dims<4>* const* input_dims,
620 const int32* input_zeropoint,
621 const float* input_scale, int inputs_count,
622 uint8* output_data, const Dims<4>& output_dims,
623 const int32 output_zeropoint,
624 const float output_scale) {
625 std::vector<RuntimeShape> input_shapes(inputs_count);
626 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
627 for (int i = 0; i < inputs_count; ++i) {
628 ShapeFromDims(*input_dims[i], &input_shapes[i]);
629 input_shapes_indirect[i] = &input_shapes[i];
630 }
631 tflite::ConcatenationParams op_params;
632 op_params.axis = 3 - concat_dim;
633 op_params.input_zeropoint = input_zeropoint;
634 op_params.input_scale = input_scale;
635 op_params.inputs_count = inputs_count;
636 op_params.output_zeropoint = output_zeropoint;
637 op_params.output_scale = output_scale;
638
639 ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
640 DimsToShape(output_dims), output_data);
641 }
642
643 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)644 void DepthConcatenation(const Scalar* const* input_data,
645 const Dims<4>* const* input_dims, int inputs_count,
646 Scalar* output_data, const Dims<4>& output_dims) {
647 // For now we don't have a model with a Concatenation with fused activation.
648 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
649
650 std::vector<RuntimeShape> input_shapes(inputs_count);
651 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
652 for (int i = 0; i < inputs_count; ++i) {
653 ShapeFromDims(*input_dims[i], &input_shapes[i]);
654 input_shapes_indirect[i] = &input_shapes[i];
655 }
656 tflite::ConcatenationParams op_params;
657 op_params.inputs_count = inputs_count;
658
659 DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
660 DimsToShape(output_dims), output_data);
661 }
662
663 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)664 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
665 int axis, int outputs_count, Scalar* const* output_data,
666 const Dims<4>* const* output_dims) {
667 std::vector<RuntimeShape> output_shapes(outputs_count);
668 std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
669 for (int i = 0; i < outputs_count; ++i) {
670 ShapeFromDims(*output_dims[i], &output_shapes[i]);
671 output_shapes_indirect[i] = &output_shapes[i];
672 }
673 tflite::SplitParams op_params;
674 op_params.axis = 3 - axis;
675 op_params.num_split = outputs_count;
676
677 Split(op_params, DimsToShape(input_dims), input_data,
678 output_shapes_indirect.data(), output_data);
679 }
680
681 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)682 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
683 int outputs_count, Scalar* const* output_data,
684 const Dims<4>* const* output_dims) {
685 TFLITE_DCHECK_GE(outputs_count, 1);
686 for (int i = 0; i < outputs_count; i++) {
687 /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
688 /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
689 /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
690 }
691 // For now we don't have a model with a Split with fused activation.
692 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
693
694 TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
695 output_data, output_dims);
696 }
697
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)698 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
699 float beta, float* output_data,
700 const RuntimeShape& output_shape) {
701 SoftmaxParams params;
702 params.beta = beta;
703 Softmax(params, input_shape, input_data, output_shape, output_data);
704 }
705
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)706 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
707 int32 input_beta_multiplier, int32 input_beta_left_shift,
708 int diff_min, uint8* output_data,
709 const RuntimeShape& output_shape) {
710 SoftmaxParams params;
711 params.input_multiplier = input_beta_multiplier;
712 params.input_left_shift = input_beta_left_shift;
713 params.diff_min = diff_min;
714 Softmax(params, input_shape, input_data, output_shape, output_data);
715 }
716
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)717 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
718 float* output_data, const RuntimeShape& output_shape) {
719 SoftmaxParams params;
720 // No params currently used for float LogSoftmax.
721 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
722 }
723
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)724 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
725 int32 input_multiplier, int32 input_left_shift,
726 int32 reverse_scaling_divisor,
727 int32 reverse_scaling_right_shift, int diff_min,
728 uint8* output_data, const RuntimeShape& output_shape) {
729 SoftmaxParams params;
730 params.input_multiplier = input_multiplier;
731 params.input_left_shift = input_left_shift;
732 params.reverse_scaling_divisor = reverse_scaling_divisor;
733 params.reverse_scaling_right_shift = reverse_scaling_right_shift;
734 params.diff_min = diff_min;
735 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
736 }
737
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)738 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
739 int32 input_zero_point, int32 input_range_radius,
740 int32 input_multiplier, int input_left_shift,
741 uint8* output_data, const RuntimeShape& output_shape) {
742 LogisticParams params;
743 params.input_zero_point = input_zero_point;
744 params.input_range_radius = input_range_radius;
745 params.input_multiplier = input_multiplier;
746 params.input_left_shift = input_left_shift;
747 Logistic(params, input_shape, input_data, output_shape, output_data);
748 }
749
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)750 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
751 const RuntimeShape& output_shape, int16* output_data) {
752 LogisticParams params;
753 // No params currently needed by int16 Logistic.
754 Logistic(params, input_shape, input_data, output_shape, output_data);
755 }
756
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)757 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
758 int32 input_zero_point, int32 input_range_radius,
759 int32 input_multiplier, int input_left_shift,
760 uint8* output_data, const RuntimeShape& output_shape) {
761 TanhParams params;
762 params.input_zero_point = input_zero_point;
763 params.input_range_radius = input_range_radius;
764 params.input_multiplier = input_multiplier;
765 params.input_left_shift = input_left_shift;
766 Tanh(params, input_shape, input_data, output_shape, output_data);
767 }
768
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)769 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
770 int input_left_shift, int16* output_data,
771 const RuntimeShape& output_shape) {
772 TanhParams params;
773 params.input_left_shift = input_left_shift;
774 Tanh(params, input_shape, input_data, output_shape, output_data);
775 }
776
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)777 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
778 int32 zero_point, double scale, float* output_data,
779 const Dims<4>& output_dims) {
780 tflite::DequantizationParams op_params;
781 op_params.zero_point = zero_point;
782 op_params.scale = scale;
783
784 Dequantize(op_params, DimsToShape(input_dims), input_data,
785 DimsToShape(output_dims), output_data);
786 }
787
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)788 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
789 float rmin, float rmax, int num_bits, float* output_data,
790 const Dims<4>& output_dims) {
791 tflite::FakeQuantParams op_params;
792 op_params.num_bits = num_bits;
793 op_params.minmax.min = rmin;
794 op_params.minmax.max = rmax;
795
796 FakeQuant(op_params, DimsToShape(input_dims), input_data,
797 DimsToShape(output_dims), output_data);
798 }
799
800 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)801 inline void Gather(const T* input_data, const Dims<4>& input_dims,
802 int input_rank, const int32* coords_data,
803 const Dims<4>& coords_dims, T* output_data,
804 const Dims<4>& output_dims) {
805 tflite::GatherParams op_params;
806 op_params.axis = 4 - input_rank;
807
808 Gather(op_params, DimsToShape(input_dims), input_data,
809 DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
810 output_data);
811 }
812
LegacyReverseBits32(uint32 n)813 inline uint32 LegacyReverseBits32(uint32 n) {
814 n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
815 n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
816 n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
817 return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
818 ((n & 0xFF000000) >> 24));
819 }
820
StridedSliceReverseIndices(tflite::StridedSliceParams * p)821 inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
822 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
823 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
824
825 std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
826 std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
827 std::reverse(p->strides, p->strides + p->strides_count);
828
829 p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
830 (32 - p->start_indices_count);
831 p->ellipsis_mask =
832 LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
833 (32 - p->start_indices_count);
834 p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
835 (32 - p->start_indices_count);
836 p->new_axis_mask =
837 LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
838 (32 - p->start_indices_count);
839 p->shrink_axis_mask =
840 LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
841 (32 - p->start_indices_count);
842 }
843
844 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)845 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
846 int begin_mask, int end_mask, int shrink_axis_mask,
847 const std::vector<int>& start_indices,
848 const std::vector<int>& stop_indices,
849 const std::vector<int>& strides, T* output_data,
850 const Dims<4>& output_dims) {
851 TFLITE_DCHECK_EQ(start_indices.size(), 4);
852 auto op_params = strided_slice::BuildStridedSliceParams(
853 begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
854 strides);
855 StridedSliceReverseIndices(&op_params);
856
857 StridedSlice(op_params, DimsToShape(input_dims), input_data,
858 DimsToShape(output_dims), output_data);
859 }
860
861 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)862 inline void Mean(const T* input_data, const Dims<4>& input_dims,
863 const std::vector<int>& reduction_indices, T* output_data,
864 const Dims<4>& output_dims) {
865 tflite::MeanParams op_params;
866 op_params.axis_count = reduction_indices.size();
867 for (int i = 0; i < op_params.axis_count; ++i) {
868 op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
869 }
870
871 Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
872 output_data);
873 }
874
875 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)876 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
877 const Dims<4>& output_dims, const int* permuted_axes) {
878 TransposeParams params;
879 params.perm_count = 4;
880 for (int i = 0; i < 4; ++i) {
881 params.perm[i] = 3 - permuted_axes[3 - i];
882 }
883 Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
884 output);
885 }
886
887 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)888 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
889 const T* input2_data, const Dims<4>& input2_dims,
890 bool* output_data, const Dims<4>& output_dims) {
891 ComparisonParams op_params;
892 // No parameters needed.
893 ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
894 DimsToShape(input2_dims), input2_data,
895 DimsToShape(output_dims), output_data);
896 }
897
898 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)899 inline void Comparison(int left_shift, const T* input1_data,
900 const Dims<4>& input1_dims, int32 input1_offset,
901 int32 input1_multiplier, int input1_shift,
902 const T* input2_data, const Dims<4>& input2_dims,
903 int32 input2_offset, int32 input2_multiplier,
904 int input2_shift, bool* output_data,
905 const Dims<4>& output_dims) {
906 tflite::ComparisonParams op_params;
907 op_params.left_shift = left_shift;
908 op_params.input1_offset = input1_offset;
909 op_params.input1_multiplier = input1_multiplier;
910 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
911 op_params.input1_shift = kReverseShift * input1_shift;
912 op_params.input2_offset = input2_offset;
913 op_params.input2_multiplier = input2_multiplier;
914 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
915 op_params.input2_shift = kReverseShift * input2_shift;
916
917 ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
918 DimsToShape(input2_dims), input2_data,
919 DimsToShape(output_dims), output_data);
920 }
921
922 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)923 inline void BroadcastComparison(const T* input1_data,
924 const Dims<4>& input1_dims,
925 const T* input2_data,
926 const Dims<4>& input2_dims, bool* output_data,
927 const Dims<4>& output_dims) {
928 ComparisonParams op_params;
929 // No parameters needed.
930 BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
931 input1_data, DimsToShape(input2_dims),
932 input2_data, DimsToShape(output_dims),
933 output_data);
934 }
935
936 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)937 inline void BroadcastComparison(int left_shift, const T* input1_data,
938 const Dims<4>& input1_dims, int32 input1_offset,
939 int32 input1_multiplier, int input1_shift,
940 const T* input2_data,
941 const Dims<4>& input2_dims, int32 input2_offset,
942 int32 input2_multiplier, int input2_shift,
943 bool* output_data, const Dims<4>& output_dims) {
944 ComparisonParams op_params;
945
946 op_params.left_shift = left_shift;
947 op_params.input1_offset = input1_offset;
948 op_params.input1_multiplier = input1_multiplier;
949 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
950 op_params.input1_shift = kReverseShift * input1_shift;
951 op_params.input2_offset = input2_offset;
952 op_params.input2_multiplier = input2_multiplier;
953 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
954 op_params.input2_shift = kReverseShift * input2_shift;
955
956 BroadcastComparison4DSlowWithScaling<T, F>(
957 op_params, DimsToShape(input1_dims), input1_data,
958 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
959 output_data);
960 }
961
// Expands to the full set of legacy entry points for one comparison op
// `name`: the plain and quantized (8-bit) variants, plus their broadcasting
// `Broadcast##name` counterparts. Each wrapper only attaches a profiling
// label and forwards to the generic Comparison / BroadcastComparison
// helpers above, using `name##Fn` as the element-wise predicate.
#define TFLITE_LEGACY_COMPARISON_OP(name)                                     \
  template <typename T>                                                       \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,          \
                   const T* input2_data, const Dims<4>& input2_dims,          \
                   bool* output_data, const Dims<4>& output_dims) {           \
    gemmlowp::ScopedProfilingLabel label(#name);                              \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,            \
                            input2_dims, output_data, output_dims);           \
  }                                                                           \
  template <typename T>                                                       \
  inline void name(                                                           \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label(#name "/8bit");                      \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,             \
                            input1_offset, input1_multiplier, input1_shift,   \
                            input2_data, input2_dims, input2_offset,          \
                            input2_multiplier, input2_shift, output_data,     \
                            output_dims);                                     \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
      const Dims<4>& input2_dims, bool* output_data,                          \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label("Broadcast" #name);                  \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,   \
                                     input2_dims, output_data, output_dims);  \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit");          \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,    \
                                     input1_offset, input1_multiplier,        \
                                     input1_shift, input2_data, input2_dims,  \
                                     input2_offset, input2_multiplier,        \
                                     input2_shift, output_data, output_dims); \
  }
// Instantiate the legacy wrappers for all six comparison predicates.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1015
1016 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1017 inline void Select(const D* input_condition_data,
1018 const Dims<4>& input_condition_dims, const T* input_x_data,
1019 const Dims<4>& input_x_dims, const T* input_y_data,
1020 const Dims<4>& input_y_dims, T* output_data,
1021 const Dims<4>& output_dims) {
1022 Select(DimsToShape(input_condition_dims), input_condition_data,
1023 DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1024 input_y_data, DimsToShape(output_dims), output_data);
1025 }
1026
1027 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1028 inline void RankOneSelect(const D* input_condition_data,
1029 const Dims<4>& input_condition_dims,
1030 const T* input_x_data, const Dims<4>& input_x_dims,
1031 const T* input_y_data, const Dims<4>& input_y_dims,
1032 T* output_data, const Dims<4>& output_dims) {
1033 RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1034 DimsToShape(input_x_dims), input_x_data,
1035 DimsToShape(input_y_dims), input_y_data,
1036 DimsToShape(output_dims), output_data);
1037 }
1038
1039 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1040 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1041 const T* values, T default_value, T* output_data,
1042 const Dims<4>& output_dims, bool value_is_scalar) {
1043 SparseToDense(indices, values, default_value, value_is_scalar,
1044 DimsToShape(output_dims), output_data);
1045 }
1046
1047 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1048 void Pack(int dim, const Scalar* const* input_data,
1049 const Dims<4>* const* input_dims, int inputs_count,
1050 Scalar* output_data, const Dims<4>& output_dims) {
1051 std::vector<RuntimeShape> input_shapes(inputs_count);
1052 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1053 for (int i = 0; i < inputs_count; ++i) {
1054 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1055 input_shapes_indirect[i] = &input_shapes[i];
1056 }
1057 tflite::PackParams op_params;
1058 op_params.axis = 3 - dim;
1059 op_params.inputs_count = inputs_count;
1060
1061 Pack(op_params, input_shapes_indirect.data(), input_data,
1062 DimsToShape(output_dims), output_data);
1063 }
1064
1065 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1066 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1067 int dimensions, int outputs_count, Scalar* const* output_datas,
1068 const Dims<4>& output_dims) {
1069 tflite::UnpackParams op_params;
1070 op_params.axis = 3 - axis;
1071 op_params.num_split = outputs_count;
1072
1073 Unpack(op_params, DimsToShape(input_dims), input_data,
1074 DimsToShape(output_dims), output_datas);
1075 }
1076
1077 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1078 void Pack(int dim, const Scalar* const* input_data,
1079 const Dims<4>* const* input_dims, const int32* input_zeropoint,
1080 const float* input_scale, int inputs_count, Scalar* output_data,
1081 const Dims<4>& output_dims, const int32 output_zeropoint,
1082 const float output_scale) {
1083 std::vector<RuntimeShape> input_shapes(inputs_count);
1084 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1085 for (int i = 0; i < inputs_count; ++i) {
1086 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1087 input_shapes_indirect[i] = &input_shapes[i];
1088 }
1089 tflite::PackParams op_params;
1090 op_params.axis = 3 - dim;
1091 op_params.input_zeropoint = input_zeropoint;
1092 op_params.input_scale = input_scale;
1093 op_params.inputs_count = inputs_count;
1094 op_params.output_zeropoint = output_zeropoint;
1095 op_params.output_scale = output_scale;
1096
1097 PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1098 DimsToShape(output_dims), output_data);
1099 }
1100
1101 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1102 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1103 float* output_data, const RuntimeShape& output_shape) {
1104 static_assert(Ac == FusedActivationFunctionType::kNone, "");
1105 tflite::L2NormalizationParams op_params;
1106 // No params need to be set for float.
1107
1108 L2Normalization(op_params, input_shape, input_data, output_shape,
1109 output_data);
1110 }
1111
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1112 inline void L2Normalization(const uint8* input_data,
1113 const RuntimeShape& input_shape,
1114 int32 input_zero_point, uint8* output_data,
1115 const RuntimeShape& output_shape) {
1116 tflite::L2NormalizationParams op_params;
1117 op_params.input_zero_point = input_zero_point;
1118
1119 L2Normalization(op_params, input_shape, input_data, output_shape,
1120 output_data);
1121 }
1122
1123 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1124 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1125 float* output_data, const Dims<4>& output_dims) {
1126 L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1127 DimsToShape(output_dims));
1128 }
1129
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1130 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1131 int32 input_zero_point, uint8* output_data,
1132 const Dims<4>& output_dims) {
1133 L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1134 output_data, DimsToShape(output_dims));
1135 }
1136
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1137 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1138 float* output_data, const Dims<4>& output_dims) {
1139 Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1140 output_data);
1141 }
1142
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1143 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1144 float* output_data, const Dims<4>& output_dims) {
1145 Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1146 output_data);
1147 }
1148
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1149 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1150 float* output_data, const Dims<4>& output_dims) {
1151 Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1152 output_data);
1153 }
1154
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1155 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1156 const RuntimeShape& input_shape, uint8* output_data,
1157 const RuntimeShape& output_shape) {
1158 tflite::ActivationParams params;
1159 params.quantized_activation_max = max_value;
1160 params.quantized_activation_min = min_value;
1161 ReluX(params, input_shape, input_data, output_shape, output_data);
1162 }
1163
// Legacy quantized uint8 Add taking Dims<4> and unpacked quantization
// parameters.  Bundles everything into ArithmeticParams and forwards to the
// shape-based Add.  The Ac template argument is only validated; the runtime
// clamp range [output_activation_min, output_activation_max] is what is
// actually applied.
template <FusedActivationFunctionType Ac>
inline void Add(int left_shift, const uint8* input1_data,
                const Dims<4>& input1_dims, int32 input1_offset,
                int32 input1_multiplier, int input1_shift,
                const uint8* input2_data, const Dims<4>& input2_dims,
                int32 input2_offset, int32 input2_multiplier, int input2_shift,
                int32 output_offset, int32 output_multiplier, int output_shift,
                int32 output_activation_min, int32 output_activation_max,
                uint8* output_data, const Dims<4>& output_dims) {
  // Legacy callers use the opposite sign convention for shifts from
  // ArithmeticParams; multiplying by -1 converts between the two.
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                Ac == FusedActivationFunctionType::kRelu ||
                Ac == FusedActivationFunctionType::kRelu6 ||
                Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // With no fused activation, the clamp range must cover all of uint8.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
1202
1203 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1204 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1205 const int32* input2_data, const Dims<4>& input2_dims,
1206 int32* output_data, const Dims<4>& output_dims) {
1207 gemmlowp::ScopedProfilingLabel label("Add/int32");
1208 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1209
1210 tflite::ArithmeticParams op_params;
1211 op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1212 op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1213 Add(op_params, DimsToShape(input1_dims), input1_data,
1214 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1215 output_data);
1216 }
1217
// Legacy quantized uint8 broadcasting Add with Dims<4> shapes and unpacked
// quantization parameters.  Forwards to the slow generic 4-D broadcast
// implementation.  As in Add above, Ac is only validated; the runtime
// activation range is what is applied.
template <FusedActivationFunctionType Ac>
inline void BroadcastAdd(int left_shift, const uint8* input1_data,
                         const Dims<4>& input1_dims, int32 input1_offset,
                         int32 input1_multiplier, int input1_shift,
                         const uint8* input2_data, const Dims<4>& input2_dims,
                         int32 input2_offset, int32 input2_multiplier,
                         int input2_shift, int32 output_offset,
                         int32 output_multiplier, int output_shift,
                         int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  // Legacy shift arguments use the opposite sign convention from
  // ArithmeticParams.
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                Ac == FusedActivationFunctionType::kRelu ||
                Ac == FusedActivationFunctionType::kRelu6 ||
                Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // With no fused activation, the clamp range must cover all of uint8.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}
1258
1259 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1260 void Add(const float* input1_data, const Dims<4>& input1_dims,
1261 const float* input2_data, const Dims<4>& input2_dims,
1262 float* output_data, const Dims<4>& output_dims) {
1263 float output_activation_min, output_activation_max;
1264 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1265
1266 tflite::ArithmeticParams op_params;
1267 op_params.float_activation_min = output_activation_min;
1268 op_params.float_activation_max = output_activation_max;
1269 Add(op_params, DimsToShape(input1_dims), input1_data,
1270 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1271 output_data);
1272 }
1273
1274 template <typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1275 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1276 const T* input2_data, const Dims<4>& input2_dims,
1277 T output_activation_min, T output_activation_max,
1278 T* output_data, const Dims<4>& output_dims) {
1279 tflite::ArithmeticParams op_params;
1280 op_params.float_activation_min = output_activation_min;
1281 op_params.float_activation_max = output_activation_max;
1282 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1283 DimsToShape(input2_dims), input2_data,
1284 DimsToShape(output_dims), output_data);
1285 }
1286
1287 template <FusedActivationFunctionType Ac>
BroadcastAddFivefold(int y0,int y1,int y2,int y3,int y4,int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1288 inline void BroadcastAddFivefold(
1289 int y0, int y1, int y2, int y3, int y4, int left_shift,
1290 const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
1291 int32 input1_multiplier, int input1_shift, const uint8* input2_data,
1292 const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
1293 int input2_shift, int32 output_offset, int32 output_multiplier,
1294 int output_shift, int32 output_activation_min, int32 output_activation_max,
1295 uint8* output_data, const Dims<4>& output_dims) {
1296 constexpr int kReverseShift = -1;
1297 static_assert(Ac == FusedActivationFunctionType::kNone ||
1298 Ac == FusedActivationFunctionType::kRelu ||
1299 Ac == FusedActivationFunctionType::kRelu6 ||
1300 Ac == FusedActivationFunctionType::kRelu1,
1301 "");
1302 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1303 if (Ac == FusedActivationFunctionType::kNone) {
1304 TFLITE_DCHECK_EQ(output_activation_min, 0);
1305 TFLITE_DCHECK_EQ(output_activation_max, 255);
1306 }
1307 tflite::ArithmeticParams op_params;
1308 op_params.broadcast_category =
1309 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
1310 op_params.left_shift = left_shift;
1311 op_params.input1_offset = input1_offset;
1312 op_params.input1_multiplier = input1_multiplier;
1313 op_params.input1_shift = kReverseShift * input1_shift;
1314 op_params.input2_offset = input2_offset;
1315 op_params.input2_multiplier = input2_multiplier;
1316 op_params.input2_shift = kReverseShift * input2_shift;
1317 op_params.output_offset = output_offset;
1318 op_params.output_multiplier = output_multiplier;
1319 op_params.output_shift = kReverseShift * output_shift;
1320 op_params.quantized_activation_min = output_activation_min;
1321 op_params.quantized_activation_max = output_activation_max;
1322 op_params.broadcast_shape[4] = y0;
1323 op_params.broadcast_shape[3] = y1;
1324 op_params.broadcast_shape[2] = y2;
1325 op_params.broadcast_shape[1] = y3;
1326 op_params.broadcast_shape[0] = y4;
1327 BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
1328 DimsToShape(input2_dims), input2_data,
1329 DimsToShape(output_dims), output_data);
1330 }
1331
1332 // legacy, for compatibility with old checked-in code
1333 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1334 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1335 const T* input2_data, const Dims<4>& input2_dims,
1336 T* output_data, const Dims<4>& output_dims) {
1337 T output_activation_min, output_activation_max;
1338 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1339
1340 BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1341 output_activation_min, output_activation_max, output_data,
1342 output_dims);
1343 }
1344
1345 template <FusedActivationFunctionType Ac>
Add(const int16 * input1_data,const Dims<4> & input1_dims,int input1_shift,const int16 * input2_data,const Dims<4> & input2_dims,int input2_shift,int16 output_activation_min,int16 output_activation_max,int16 * output_data,const Dims<4> & output_dims)1346 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
1347 int input1_shift, const int16* input2_data,
1348 const Dims<4>& input2_dims, int input2_shift,
1349 int16 output_activation_min, int16 output_activation_max,
1350 int16* output_data, const Dims<4>& output_dims) {
1351 static_assert(Ac == FusedActivationFunctionType::kNone ||
1352 Ac == FusedActivationFunctionType::kRelu ||
1353 Ac == FusedActivationFunctionType::kRelu6 ||
1354 Ac == FusedActivationFunctionType::kRelu1,
1355 "");
1356 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1357 if (Ac == FusedActivationFunctionType::kNone) {
1358 TFLITE_DCHECK_EQ(output_activation_min, -32768);
1359 TFLITE_DCHECK_EQ(output_activation_max, 32767);
1360 }
1361
1362 tflite::ArithmeticParams op_params;
1363 op_params.input1_shift = kReverseShift * input1_shift;
1364 op_params.input2_shift = kReverseShift * input2_shift;
1365 op_params.quantized_activation_min = output_activation_min;
1366 op_params.quantized_activation_max = output_activation_max;
1367 Add(op_params, DimsToShape(input1_dims), input1_data,
1368 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1369 output_data);
1370 }
1371
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1372 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1373 const float* input2_data, const Dims<4>& input2_dims,
1374 float* output_data, const Dims<4>& output_dims) {
1375 float output_activation_min, output_activation_max;
1376 GetActivationMinMax(FusedActivationFunctionType::kNone,
1377 &output_activation_min, &output_activation_max);
1378 tflite::ArithmeticParams op_params;
1379 op_params.float_activation_min = output_activation_min;
1380 op_params.float_activation_max = output_activation_max;
1381 Sub(op_params, DimsToShape(input1_dims), input1_data,
1382 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1383 output_data);
1384 }
1385
1386 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1387 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1388 const Dims<4>& input2_dims, T* output_data,
1389 const Dims<4>& output_dims) {
1390 tflite::ArithmeticParams op_params;
1391 op_params.quantized_activation_min = std::numeric_limits<T>::min();
1392 op_params.quantized_activation_max = std::numeric_limits<T>::max();
1393 Sub(op_params, DimsToShape(input1_dims), input1_data,
1394 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1395 output_data);
1396 }
1397
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1398 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
1399 int stride_width, int stride_height, int pad_width,
1400 int pad_height, int kwidth, int kheight,
1401 float output_activation_min,
1402 float output_activation_max, float* output_data,
1403 const Dims<4>& output_dims) {
1404 tflite::PoolParams params;
1405 params.stride_height = stride_height;
1406 params.stride_width = stride_width;
1407 params.filter_height = kheight;
1408 params.filter_width = kwidth;
1409 params.padding_values.height = pad_height;
1410 params.padding_values.width = pad_width;
1411 params.float_activation_min = output_activation_min;
1412 params.float_activation_max = output_activation_max;
1413 AveragePool(params, DimsToShape(input_dims), input_data,
1414 DimsToShape(output_dims), output_data);
1415 }
1416
// Legacy quantized uint8 broadcast multiply with Dims<4> shapes and unpacked
// quantization parameters; forwards to the params-based kernel.  Note that
// output_shift is passed through unchanged here — the BroadcastMul wrapper
// below applies the kReverseShift sign flip before calling this function.
inline void BroadcastMul4DSlow(const uint8* input1_data,
                               const Dims<4>& input1_dims, int32 input1_offset,
                               const uint8* input2_data,
                               const Dims<4>& input2_dims, int32 input2_offset,
                               int32 output_offset, int32 output_multiplier,
                               int output_shift, int32 output_activation_min,
                               int32 output_activation_max, uint8* output_data,
                               const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);
  op_params.input1_offset = input1_offset;
  op_params.input2_offset = input2_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = output_shift;

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}
1439
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1440 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1441 int32 input1_offset, const uint8* input2_data,
1442 const Dims<4>& input2_dims, int32 input2_offset,
1443 int32 output_offset, int32 output_multiplier,
1444 int output_shift, int32 output_activation_min,
1445 int32 output_activation_max, uint8* output_data,
1446 const Dims<4>& output_dims) {
1447 BroadcastMul4DSlow(
1448 input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1449 input2_offset, output_offset, output_multiplier,
1450 //
1451 kReverseShift * output_shift,
1452 //
1453 output_activation_min, output_activation_max, output_data, output_dims);
1454 }
1455
// legacy, for compatibility with old checked-in code
// Templated variant retained only for source compatibility: the Ac template
// argument is ignored and the call is forwarded unchanged to the
// non-templated overload above.
template <FusedActivationFunctionType Ac>
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
                         int32 input1_offset, const uint8* input2_data,
                         const Dims<4>& input2_dims, int32 input2_offset,
                         int32 output_offset, int32 output_multiplier,
                         int output_shift, int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
               input2_dims, input2_offset, output_offset, output_multiplier,
               output_shift, output_activation_min, output_activation_max,
               output_data, output_dims);
}
1470
1471 // legacy, for compatibility with old checked-in code
1472 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1473 void AveragePool(const float* input_data, const Dims<4>& input_dims,
1474 int stride_width, int stride_height, int pad_width,
1475 int pad_height, int kwidth, int kheight, float* output_data,
1476 const Dims<4>& output_dims) {
1477 float output_activation_min, output_activation_max;
1478 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1479
1480 AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
1481 pad_height, kwidth, kheight, output_activation_min,
1482 output_activation_max, output_data, output_dims);
1483 }
1484
// legacy, for compatibility with old checked-in code
// Single-stride convenience overload: uses the same stride for both width
// and height and forwards to the two-stride templated overload.
template <FusedActivationFunctionType Ac>
void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
                 int pad_width, int pad_height, int filter_width,
                 int filter_height, float* output_data,
                 const Dims<4>& output_dims) {
  AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
                  filter_width, filter_height, output_data, output_dims);
}
1494
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1495 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1496 int stride_width, int stride_height, int pad_width,
1497 int pad_height, int filter_width, int filter_height,
1498 int32 output_activation_min,
1499 int32 output_activation_max, uint8* output_data,
1500 const Dims<4>& output_dims) {
1501 tflite::PoolParams params;
1502 params.stride_height = stride_height;
1503 params.stride_width = stride_width;
1504 params.filter_height = filter_height;
1505 params.filter_width = filter_width;
1506 params.padding_values.height = pad_height;
1507 params.padding_values.width = pad_width;
1508 params.quantized_activation_min = output_activation_min;
1509 params.quantized_activation_max = output_activation_max;
1510 AveragePool(params, DimsToShape(input_dims), input_data,
1511 DimsToShape(output_dims), output_data);
1512 }
1513
// legacy, for compatibility with old checked-in code
// Templated variant: validates that Ac is a supported fused activation and
// that a kNone activation uses the full uint8 clamp range, then forwards to
// the non-templated overload.
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int filter_width, int filter_height,
                 int32 output_activation_min, int32 output_activation_max,
                 uint8* output_data, const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                Ac == FusedActivationFunctionType::kRelu ||
                Ac == FusedActivationFunctionType::kRelu6 ||
                Ac == FusedActivationFunctionType::kRelu1,
                "");
  // With no fused activation, the clamp range must cover all of uint8.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
              pad_height, filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}
1534
// legacy, for compatibility with old checked-in code
// Single-stride convenience overload: uses the same stride for both width
// and height and forwards to the two-stride templated overload.
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
                 int pad_width, int pad_height, int filter_width,
                 int filter_height, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims) {
  AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
                  filter_width, filter_height, output_activation_min,
                  output_activation_max, output_data, output_dims);
}
1546
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1547 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1548 int stride_width, int stride_height, int pad_width,
1549 int pad_height, int kwidth, int kheight,
1550 float output_activation_min, float output_activation_max,
1551 float* output_data, const Dims<4>& output_dims) {
1552 tflite::PoolParams params;
1553 params.stride_height = stride_height;
1554 params.stride_width = stride_width;
1555 params.filter_height = kheight;
1556 params.filter_width = kwidth;
1557 params.padding_values.height = pad_height;
1558 params.padding_values.width = pad_width;
1559 params.float_activation_min = output_activation_min;
1560 params.float_activation_max = output_activation_max;
1561 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1562 output_data);
1563 }
1564
1565 // legacy, for compatibility with old checked-in code
1566 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1567 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1568 int stride_width, int stride_height, int pad_width, int pad_height,
1569 int kwidth, int kheight, float* output_data,
1570 const Dims<4>& output_dims) {
1571 float output_activation_min, output_activation_max;
1572 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1573 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1574 pad_height, kwidth, kheight, output_activation_min,
1575 output_activation_max, output_data, output_dims);
1576 }
1577
// legacy, for compatibility with old checked-in code
// Single-stride convenience overload: uses the same stride for both width
// and height and forwards to the two-stride templated overload.
template <FusedActivationFunctionType Ac>
void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int filter_width, int filter_height,
             float* output_data, const Dims<4>& output_dims) {
  MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
              filter_width, filter_height, output_data, output_dims);
}
1586
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1587 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1588 int stride_width, int stride_height, int pad_width,
1589 int pad_height, int filter_width, int filter_height,
1590 int32 output_activation_min, int32 output_activation_max,
1591 uint8* output_data, const Dims<4>& output_dims) {
1592 PoolParams params;
1593 params.stride_height = stride_height;
1594 params.stride_width = stride_width;
1595 params.filter_height = filter_height;
1596 params.filter_width = filter_width;
1597 params.padding_values.height = pad_height;
1598 params.padding_values.width = pad_width;
1599 params.quantized_activation_min = output_activation_min;
1600 params.quantized_activation_max = output_activation_max;
1601 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1602 output_data);
1603 }
1604
// legacy, for compatibility with old checked-in code
// Quantized MaxPool with the fused activation given as a template parameter.
// Validates at compile time that Ac is a supported activation, then forwards
// to the runtime-min/max overload above.
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
             int stride_width, int stride_height, int pad_width, int pad_height,
             int filter_width, int filter_height, int32 output_activation_min,
             int32 output_activation_max, uint8* output_data,
             const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  // kNone means no clamping beyond the natural uint8 range, so the caller is
  // expected to pass the full [0, 255] interval.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
          pad_height, filter_width, filter_height, output_activation_min,
          output_activation_max, output_data, output_dims);
}
1625
// legacy, for compatibility with old checked-in code
// Single-stride quantized overload: applies the same stride to both width and
// height and forwards to the two-stride overload.
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int filter_width, int filter_height,
             int32 output_activation_min, int32 output_activation_max,
             uint8* output_data, const Dims<4>& output_dims) {
  MaxPool<Ac>(input_data, input_dims, /*stride_width=*/stride,
              /*stride_height=*/stride, pad_width, pad_height,
              filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}
1636
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1637 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1638 int stride_width, int stride_height, int pad_width,
1639 int pad_height, int filter_width, int filter_height,
1640 float output_activation_min, float output_activation_max,
1641 float* output_data, const Dims<4>& output_dims) {
1642 PoolParams params;
1643 params.stride_height = stride_height;
1644 params.stride_width = stride_width;
1645 params.filter_height = filter_height;
1646 params.filter_width = filter_width;
1647 params.padding_values.height = pad_height;
1648 params.padding_values.width = pad_width;
1649 params.float_activation_min = output_activation_min;
1650 params.float_activation_max = output_activation_max;
1651 L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1652 output_data);
1653 }
1654
1655 // legacy, for compatibility with old checked-in code
1656 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1657 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1658 int stride_width, int stride_height, int pad_width, int pad_height,
1659 int filter_width, int filter_height, float* output_data,
1660 const Dims<4>& output_dims) {
1661 float output_activation_min, output_activation_max;
1662 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1663 L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1664 pad_height, filter_width, filter_height, output_activation_min,
1665 output_activation_max, output_data, output_dims);
1666 }
1667
// legacy, for compatibility with old checked-in code
// Single-stride overload: applies the same stride to both width and height
// and forwards to the two-stride overload.
template <FusedActivationFunctionType Ac>
void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
            int pad_width, int pad_height, int filter_width, int filter_height,
            float* output_data, const Dims<4>& output_dims) {
  L2Pool<Ac>(input_data, input_dims, /*stride_width=*/stride,
             /*stride_height=*/stride, pad_width, pad_height,
             filter_width, filter_height, output_data, output_dims);
}
1676
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1677 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1678 float beta, float* output_data,
1679 const Dims<4>& output_dims) {
1680 Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1681 DimsToShape(output_dims));
1682 }
1683
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1684 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1685 int32 input_beta_multiplier, int32 input_beta_left_shift,
1686 int diff_min, uint8* output_data,
1687 const Dims<4>& output_dims) {
1688 Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1689 input_beta_left_shift, diff_min, output_data,
1690 DimsToShape(output_dims));
1691 }
1692
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1693 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1694 float* output_data, const Dims<4>& output_dims) {
1695 LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1696 DimsToShape(output_dims));
1697 }
1698
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1699 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1700 int32 input_multiplier, int32 input_left_shift,
1701 int32 reverse_scaling_divisor,
1702 int32 reverse_scaling_right_shift, int diff_min,
1703 uint8* output_data, const Dims<4>& output_dims) {
1704 LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1705 input_left_shift, reverse_scaling_divisor,
1706 reverse_scaling_right_shift, diff_min, output_data,
1707 DimsToShape(output_dims));
1708 }
1709
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1710 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1711 float* output_data, const Dims<4>& output_dims) {
1712 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1713 output_data);
1714 }
1715
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1716 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1717 int32 input_zero_point, int32 input_range_radius,
1718 int32 input_multiplier, int input_left_shift,
1719 uint8* output_data, const Dims<4>& output_dims) {
1720 Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1721 input_range_radius, input_multiplier, input_left_shift, output_data,
1722 DimsToShape(output_dims));
1723 }
1724
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1725 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1726 int16* output_data, const Dims<4>& output_dims) {
1727 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1728 output_data);
1729 }
1730
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1731 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1732 float* output_data, const Dims<4>& output_dims) {
1733 Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1734 output_data);
1735 }
1736
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1737 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1738 int32 input_zero_point, int32 input_range_radius,
1739 int32 input_multiplier, int input_left_shift,
1740 uint8* output_data, const Dims<4>& output_dims) {
1741 Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1742 input_range_radius, input_multiplier, input_left_shift, output_data,
1743 DimsToShape(output_dims));
1744 }
1745
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1746 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1747 int input_left_shift, int16* output_data,
1748 const Dims<4>& output_dims) {
1749 Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1750 DimsToShape(output_dims));
1751 }
1752
1753 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1754 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1755 int block_size, T* output_data,
1756 const Dims<4>& output_dims) {
1757 tflite::DepthToSpaceParams op_params;
1758 op_params.block_size = block_size;
1759
1760 DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1761 DimsToShape(output_dims), output_data);
1762 }
1763
1764 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1765 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1766 int block_size, T* output_data,
1767 const Dims<4>& output_dims) {
1768 tflite::SpaceToDepthParams op_params;
1769 op_params.block_size = block_size;
1770
1771 SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1772 DimsToShape(output_dims), output_data);
1773 }
1774
1775 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1776 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1777 const T* input2_data, const Dims<4>& input2_dims,
1778 T output_activation_min, T output_activation_max,
1779 T* output_data, const Dims<4>& output_dims) {
1780 tflite::ArithmeticParams op_params;
1781 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1782
1783 Mul(op_params, DimsToShape(input1_dims), input1_data,
1784 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1785 output_data);
1786 }
1787
1788 // legacy, for compatibility with old checked-in code
1789 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1790 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1791 const float* input2_data, const Dims<4>& input2_dims,
1792 float* output_data, const Dims<4>& output_dims) {
1793 float output_activation_min, output_activation_max;
1794 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1795
1796 tflite::ArithmeticParams op_params;
1797 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1798
1799 Mul(op_params, DimsToShape(input1_dims), input1_data,
1800 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1801 output_data);
1802 }
1803
1804 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1805 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1806 const T* input2_data, const Dims<4>& input2_dims,
1807 T output_activation_min, T output_activation_max,
1808 T* output_data, const Dims<4>& output_dims) {
1809 tflite::ArithmeticParams op_params;
1810 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1811
1812 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1813 DimsToShape(input2_dims), input2_data,
1814 DimsToShape(output_dims), output_data);
1815 }
1816
1817 // legacy, for compatibility with old checked-in code
1818 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1819 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1820 const T* input2_data, const Dims<4>& input2_dims,
1821 T* output_data, const Dims<4>& output_dims) {
1822 T output_activation_min, output_activation_max;
1823 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1824
1825 tflite::ArithmeticParams op_params;
1826 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1827
1828 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1829 DimsToShape(input2_dims), input2_data,
1830 DimsToShape(output_dims), output_data);
1831 }
1832
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1833 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1834 const int16* input2_data, const Dims<4>& input2_dims,
1835 int16* output_data, const Dims<4>& output_dims) {
1836 tflite::ArithmeticParams op_params;
1837 // No params in this version.
1838
1839 Mul(op_params, DimsToShape(input1_dims), input1_data,
1840 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1841 output_data);
1842 }
1843
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1844 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1845 const int16* input2_data, const Dims<4>& input2_dims,
1846 int32 output_offset, int32 output_activation_min,
1847 int32 output_activation_max, uint8* output_data,
1848 const Dims<4>& output_dims) {
1849 tflite::ArithmeticParams op_params;
1850 op_params.quantized_activation_min = output_activation_min;
1851 op_params.quantized_activation_max = output_activation_max;
1852 op_params.output_offset = output_offset;
1853
1854 Mul(op_params, DimsToShape(input1_dims), input1_data,
1855 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1856 output_data);
1857 }
1858
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1859 inline void LocalResponseNormalization(const float* input_data,
1860 const Dims<4>& input_dims, int range,
1861 float bias, float alpha, float beta,
1862 float* output_data,
1863 const Dims<4>& output_dims) {
1864 tflite::LocalResponseNormalizationParams op_params;
1865 op_params.range = range;
1866 op_params.bias = bias;
1867 op_params.alpha = alpha;
1868 op_params.beta = beta;
1869
1870 LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1871 DimsToShape(output_dims), output_data);
1872 }
1873
1874 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1875 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1876 const Dims<4>& output_dims) {
1877 Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1878 output_data);
1879 }
1880
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1881 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1882 float* output_data, const Dims<4>& output_dims) {
1883 Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1884 output_data);
1885 }
1886
1887 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1888 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1889 const int32* output_size_data,
1890 const Dims<4>& output_size_dims, T* output_data,
1891 const Dims<4>& output_dims, bool align_corners) {
1892 tflite::ResizeBilinearParams op_params;
1893 op_params.align_corners = align_corners;
1894 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1895 DimsToShape(output_size_dims), output_size_data,
1896 DimsToShape(output_dims), output_data);
1897 }
1898
// legacy, for compatibility with old checked-in code
// Float ResizeBilinear without an align_corners argument; the historical
// behavior is align_corners == false.
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear<float>(input_data, input_dims, output_size_data,
                        output_size_dims, output_data, output_dims,
                        /*align_corners=*/false);
}
1908
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)1909 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
1910 const int32* output_size_data,
1911 const Dims<4>& output_size_dims, uint8* output_data,
1912 const Dims<4>& output_dims) {
1913 ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
1914 output_size_dims, output_data, output_dims,
1915 /*align_corners=*/false);
1916 }
1917
1918 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)1919 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
1920 const int32* block_shape_data,
1921 const Dims<4>& block_shape_dims,
1922 const int32* paddings_data,
1923 const Dims<4>& paddings_dims, T* output_data,
1924 const Dims<4>& output_dims,
1925 const int32_t pad_value) {
1926 tflite::SpaceToBatchParams op_params;
1927 op_params.output_offset = pad_value;
1928
1929 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
1930 DimsToShape(block_shape_dims), block_shape_data,
1931 DimsToShape(paddings_dims), paddings_data,
1932 DimsToShape(output_dims), output_data);
1933 }
1934
1935 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)1936 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
1937 const int32* block_shape_data,
1938 const Dims<4>& block_shape_dims,
1939 const int32* paddings_data,
1940 const Dims<4>& paddings_dims, T* output_data,
1941 const Dims<4>& output_dims) {
1942 tflite::SpaceToBatchParams op_params;
1943 op_params.output_offset = 0;
1944
1945 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
1946 DimsToShape(block_shape_dims), block_shape_data,
1947 DimsToShape(paddings_dims), paddings_data,
1948 DimsToShape(output_dims), output_data);
1949 }
1950
1951 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)1952 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
1953 const int32* block_shape_data,
1954 const Dims<4>& block_shape_dims,
1955 const int32* crops_data, const Dims<4>& crops_dims,
1956 T* output_data, const Dims<4>& output_dims) {
1957 BatchToSpaceND(DimsToShape(input_dims), input_data,
1958 DimsToShape(block_shape_dims), block_shape_data,
1959 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
1960 output_data);
1961 }
1962
1963 // Legacy signature, function covered both Pad and PadV2.
1964 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)1965 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
1966 const std::vector<int>& left_paddings,
1967 const std::vector<int>& right_paddings, T* output_data,
1968 const Dims<4>& output_dims, const T pad_value) {
1969 TFLITE_DCHECK_EQ(left_paddings.size(), 4);
1970 TFLITE_DCHECK_EQ(right_paddings.size(), 4);
1971 tflite::PadParams op_params;
1972 op_params.left_padding_count = 4;
1973 op_params.right_padding_count = 4;
1974 for (int i = 0; i < 4; ++i) {
1975 op_params.left_padding[i] = left_paddings[3 - i];
1976 op_params.right_padding[i] = right_paddings[3 - i];
1977 }
1978 // SetFloatOrInt(pad_value, &op_params.pad_value);
1979 const T pad_value_copy = pad_value;
1980
1981 Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
1982 DimsToShape(output_dims), output_data);
1983 }
1984
1985 // Old Pad that calls legacy PadV2.
1986 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)1987 inline void Pad(const T* input_data, const Dims<4>& input_dims,
1988 const std::vector<int>& left_paddings,
1989 const std::vector<int>& right_paddings, T* output_data,
1990 const Dims<4>& output_dims, const int32_t pad_value) {
1991 const T converted_pad_value = static_cast<T>(pad_value);
1992 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
1993 output_dims, converted_pad_value);
1994 }
1995
1996 // Old Pad that only padded with 0.
1997 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)1998 inline void Pad(const T* input_data, const Dims<4>& input_dims,
1999 const std::vector<int>& left_paddings,
2000 const std::vector<int>& right_paddings, T* output_data,
2001 const Dims<4>& output_dims) {
2002 const T pad_value = static_cast<T>(0);
2003 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2004 output_dims, pad_value);
2005 }
2006
2007 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2008 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2009 const T* input2_data, T* output_data,
2010 const Dims<4>& output_dims) {
2011 Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2012 DimsToShape(output_dims), output_data);
2013 }
2014
2015 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2016 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2017 const T* input2_data, T* output_data,
2018 const Dims<4>& output_dims) {
2019 Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2020 DimsToShape(output_dims), output_data);
2021 }
2022
2023 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2024 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2025 const T* input2_data, const Dims<4>& input2_dims,
2026 T* output_data, const Dims<4>& output_dims,
2027 Op op) {
2028 MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
2029 DimsToShape(input2_dims), input2_data,
2030 DimsToShape(output_dims), output_data, op);
2031 }
2032
// Legacy ArgMax over the innermost axis of a 4-D Dims-based tensor.
// T1: input element type. T2: output index type. T3: axis element type.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[3] is always 1.
  // (NOTE(review): the original comment said sizes[4], which is out of range
  // for a Dims<4>.)
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this: reduction over axis value 3.
  TFLITE_DCHECK_EQ(axis[0], 3);
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, std::greater<T1>());
}
2049
2050 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2051 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2052 T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2053 ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2054 output_data, cmp);
2055 }
2056
2057 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2058 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2059 const T* input2_data, const Dims<4>& input2_dims,
2060 T* output_data, const Dims<4>& output_dims) {
2061 Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2062 input2_data, DimsToShape(output_dims), output_data);
2063 }
2064
2065 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2066 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2067 const T* input2_data, const Dims<4>& input2_dims,
2068 T* output_data, const Dims<4>& output_dims) {
2069 BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2070 DimsToShape(input2_dims), input2_data,
2071 DimsToShape(output_dims), output_data);
2072 }
2073
Logical(const bool * input1_data,const Dims<4> & input1_dims,const bool * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims,const std::function<bool (bool,bool)> & func)2074 inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
2075 const bool* input2_data, const Dims<4>& input2_dims,
2076 bool* output_data, const Dims<4>& output_dims,
2077 const std::function<bool(bool, bool)>& func) {
2078 Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2079 input2_data, DimsToShape(output_dims), output_data, func);
2080 }
2081
BroadcastLogical(const bool * input1_data,const Dims<4> & input1_dims,const bool * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims,const std::function<bool (bool,bool)> & func)2082 inline void BroadcastLogical(const bool* input1_data,
2083 const Dims<4>& input1_dims,
2084 const bool* input2_data,
2085 const Dims<4>& input2_dims, bool* output_data,
2086 const Dims<4>& output_dims,
2087 const std::function<bool(bool, bool)>& func) {
2088 BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
2089 DimsToShape(input2_dims), input2_data,
2090 DimsToShape(output_dims), output_data, func);
2091 }
2092
2093 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2094 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2095 inline void BroadcastBinaryFunction(const T1* input1_data,
2096 const Dims<4>& input1_dims,
2097 const T2* input2_data,
2098 const Dims<4>& input2_dims, R* output_data,
2099 const Dims<4>& output_dims,
2100 R (*func)(T1, T2)) {
2101 BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2102 DimsToShape(input2_dims), input2_data,
2103 DimsToShape(output_dims), output_data, func);
2104 }
2105
2106 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2107 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2108 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2109 const T2* input2_data, const Dims<4>& input2_dims,
2110 R* output_data, const Dims<4>& output_dims,
2111 R (*func)(T1, T2)) {
2112 BinaryFunction(DimsToShape(input1_dims), input1_data,
2113 DimsToShape(input2_dims), input2_data,
2114 DimsToShape(output_dims), output_data, func);
2115 }
2116
2117 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2118 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2119 const std::vector<int>& begin, const std::vector<int>& size,
2120 T* output_data, const Dims<4>& output_dims) {
2121 tflite::SliceParams op_params;
2122 op_params.begin_count = 4;
2123 op_params.size_count = 4;
2124 for (int i = 0; i < 4; ++i) {
2125 op_params.begin[i] = begin[3 - i];
2126 op_params.size[i] = size[3 - i];
2127 }
2128
2129 Slice(op_params, DimsToShape(input_dims), input_data,
2130 DimsToShape(output_dims), output_data);
2131 }
2132
2133 } // namespace reference_ops
2134 } // namespace tflite
2135 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2136