1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/hexagon/utils.h"
16
17 #include <vector>
18
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace {
26
IsActivationReluOrNone(TfLiteFusedActivation activation)27 bool IsActivationReluOrNone(TfLiteFusedActivation activation) {
28 return (activation == kTfLiteActRelu || activation == kTfLiteActRelu6 ||
29 activation == kTfLiteActReluN1To1 || activation == kTfLiteActNone);
30 }
31
TensorTypeMatch(int tensor_id,TfLiteContext * context,TfLiteType tensor_type)32 bool TensorTypeMatch(int tensor_id, TfLiteContext* context,
33 TfLiteType tensor_type) {
34 const auto& tensor = context->tensors[tensor_id];
35 return tensor.type == tensor_type;
36 }
37
38 // For each input tensor i, checks if the type matches one of the possibilities
39 // in per_input_possible_types[i].
InputsWithCorrectTypes(const TfLiteNode * node,TfLiteContext * context,const std::vector<std::vector<TfLiteType>> & per_input_possible_types)40 bool InputsWithCorrectTypes(
41 const TfLiteNode* node, TfLiteContext* context,
42 const std::vector<std::vector<TfLiteType>>& per_input_possible_types) {
43 if (node->inputs->size != per_input_possible_types.size()) return false;
44 for (int i = 0; i < per_input_possible_types.size(); ++i) {
45 // Skip optional tensor.
46 if (node->inputs->data[i] == -1) continue;
47 bool type_found = false;
48 for (auto possible_type : per_input_possible_types[i]) {
49 if (TensorTypeMatch(node->inputs->data[i], context, possible_type)) {
50 type_found = true;
51 break;
52 }
53 }
54 if (!type_found) return false;
55 }
56 return true;
57 }
58
59 } // namespace
60
Get4DShape(unsigned int * batch_size,unsigned int * height_size,unsigned int * width_size,unsigned int * depth_size,TfLiteIntArray * dims)61 TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
62 unsigned int* width_size, unsigned int* depth_size,
63 TfLiteIntArray* dims) {
64 if (dims->size > 4) return kTfLiteError;
65 unsigned int* dim[] = {batch_size, height_size, width_size, depth_size};
66 for (int i = 0; i < 4; ++i) *(dim[i]) = 1;
67 for (int i = 4 - dims->size; i < 4; ++i) {
68 *dim[i] = dims->data[i - (4 - dims->size)];
69 }
70 return kTfLiteOk;
71 }
72
73 // We maintain an op-version allowlist here to ensure we don't accept unintended
74 // ops.
CheckOpVersion(const TfLiteRegistration * registration)75 bool CheckOpVersion(const TfLiteRegistration* registration) {
76 switch (registration->builtin_code) {
77 case kTfLiteBuiltinAdd:
78 case kTfLiteBuiltinArgMax:
79 case kTfLiteBuiltinArgMin:
80 case kTfLiteBuiltinAveragePool2d:
81 case kTfLiteBuiltinConcatenation:
82 case kTfLiteBuiltinL2Normalization:
83 case kTfLiteBuiltinLogistic:
84 case kTfLiteBuiltinMaximum:
85 case kTfLiteBuiltinMaxPool2d:
86 case kTfLiteBuiltinMean:
87 case kTfLiteBuiltinMinimum:
88 case kTfLiteBuiltinMirrorPad:
89 case kTfLiteBuiltinMul:
90 case kTfLiteBuiltinPack:
91 case kTfLiteBuiltinPad:
92 case kTfLiteBuiltinQuantize:
93 case kTfLiteBuiltinRelu6:
94 case kTfLiteBuiltinSlice:
95 case kTfLiteBuiltinSoftmax:
96 case kTfLiteBuiltinSpaceToDepth:
97 case kTfLiteBuiltinDepthToSpace:
98 case kTfLiteBuiltinSplit:
99 case kTfLiteBuiltinStridedSlice:
100 case kTfLiteBuiltinSub:
101 case kTfLiteBuiltinTanh:
102 case kTfLiteBuiltinTranspose:
103 return registration->version <= 2;
104 case kTfLiteBuiltinSquaredDifference:
105 case kTfLiteBuiltinRelu:
106 case kTfLiteBuiltinRsqrt:
107 return registration->version == 2;
108 case kTfLiteBuiltinConv2d:
109 case kTfLiteBuiltinDepthwiseConv2d:
110 case kTfLiteBuiltinResizeBilinear:
111 case kTfLiteBuiltinResizeNearestNeighbor:
112 case kTfLiteBuiltinTransposeConv:
113 return registration->version <= 3;
114 case kTfLiteBuiltinFullyConnected:
115 return registration->version <= 4;
116 default:
117 return registration->version == 1;
118 }
119 }
120
IsNodeSupportedByHexagon(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)121 bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
122 const TfLiteNode* node, TfLiteContext* context) {
123 // Ensure all inputs & outputs have dim <= 4.
124 int tensor_id;
125 for (int i = 0; i < node->inputs->size; ++i) {
126 tensor_id = node->inputs->data[i];
127 // Skip optional tensors. Builders should handle optional tensors
128 // not available.
129 if (tensor_id == -1) continue;
130 const auto& tensor = context->tensors[tensor_id];
131 if (tensor.dims->size > 4) return false;
132 }
133 for (int i = 0; i < node->outputs->size; ++i) {
134 tensor_id = node->outputs->data[i];
135 const auto& tensor = context->tensors[tensor_id];
136 if (tensor.dims->size > 4) return false;
137 }
138
139 if (!CheckOpVersion(registration)) return false;
140
141 switch (registration->builtin_code) {
142 case kTfLiteBuiltinAdd: {
143 if (!InputsWithCorrectTypes(
144 node, context,
145 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
146 return false;
147 const TfLiteAddParams* add_params =
148 reinterpret_cast<const TfLiteAddParams*>(node->builtin_data);
149 return IsActivationReluOrNone(add_params->activation);
150 }
151 case kTfLiteBuiltinMul: {
152 if (!InputsWithCorrectTypes(
153 node, context,
154 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
155 return false;
156 const TfLiteMulParams* mul_params =
157 reinterpret_cast<const TfLiteMulParams*>(node->builtin_data);
158 // TODO(b/129276536): Add support for activation on Mul node.
159 return IsActivationReluOrNone(mul_params->activation);
160 }
161 case kTfLiteBuiltinSub: {
162 if (!InputsWithCorrectTypes(
163 node, context,
164 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
165 return false;
166 const TfLiteSubParams* sub_params =
167 reinterpret_cast<const TfLiteSubParams*>(node->builtin_data);
168 return IsActivationReluOrNone(sub_params->activation);
169 }
170 case kTfLiteBuiltinSum:
171 // TODO(b/139277813): Enable these when they pass unit tests. These seem
172 // to recompute the output min/max instead of taking them as inputs, which
173 // causes an unexpected shift in dequantized values.
174 return false;
175 case kTfLiteBuiltinMean: {
176 return InputsWithCorrectTypes(
177 node, context,
178 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
179 IsConstantTensor(GetInput(context, node, 1));
180 }
181 case kTfLiteBuiltinMirrorPad: {
182 if (!InputsWithCorrectTypes(
183 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
184 !IsConstantTensor(GetInput(context, node, 1)))
185 return false;
186 const TfLiteMirrorPaddingParams* params =
187 reinterpret_cast<const TfLiteMirrorPaddingParams*>(
188 node->builtin_data);
189 return params->mode == kTfLiteMirrorPaddingReflect ||
190 params->mode == kTfLiteMirrorPaddingSymmetric;
191 }
192 case kTfLiteBuiltinPad: {
193 // TODO(b/139277813): Currently we only support padding with the default
194 // of 0. Add support for user-defined constant if required.
195 return (
196 node->inputs->size == 2 &&
197 InputsWithCorrectTypes(
198 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
199 IsConstantTensor(GetInput(context, node, 1)));
200 }
201 case kTfLiteBuiltinFullyConnected: {
202 if (!InputsWithCorrectTypes(node, context,
203 {{kTfLiteUInt8, kTfLiteInt8},
204 {kTfLiteUInt8, kTfLiteInt8},
205 {kTfLiteInt32, kTfLiteNoType}})) {
206 return false;
207 }
208
209 bool bias_const_or_no_bias = true;
210 if (node->inputs->data[2] != -1) {
211 const auto& bias_tensor = context->tensors[node->inputs->data[2]];
212 bias_const_or_no_bias = bias_tensor.allocation_type == kTfLiteMmapRo;
213 }
214
215 const TfLiteFullyConnectedParams* matmul_params =
216 reinterpret_cast<const TfLiteFullyConnectedParams*>(
217 node->builtin_data);
218 return (bias_const_or_no_bias &&
219 IsActivationReluOrNone(matmul_params->activation) &&
220 matmul_params->keep_num_dims == false &&
221 matmul_params->weights_format ==
222 kTfLiteFullyConnectedWeightsFormatDefault);
223 }
224 case kTfLiteBuiltinConcatenation: {
225 // All concatenated tensors must be 8-bit.
226 for (int i = 0; i < node->inputs->size; ++i) {
227 if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
228 !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
229 return false;
230 }
231 return true;
232 }
233 case kTfLiteBuiltinMaxPool2d: {
234 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
235 return false;
236 // TODO(b/129276536): Add support for activation here.
237 const TfLitePoolParams* pool_params =
238 reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
239 return pool_params->activation == kTfLiteActNone;
240 }
241 case kTfLiteBuiltinAveragePool2d: {
242 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
243 return false;
244 const TfLitePoolParams* pool_params =
245 reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
246 return (node->inputs->size == 1 &&
247 pool_params->activation == kTfLiteActNone);
248 }
249 case kTfLiteBuiltinTransposeConv: {
250 if (NumInputs(node) == 3) {
251 if (!InputsWithCorrectTypes(node, context,
252 {{kTfLiteInt32},
253 {kTfLiteUInt8, kTfLiteInt8},
254 {kTfLiteUInt8, kTfLiteInt8}}))
255 return false;
256 } else if (NumInputs(node) == 4) {
257 if (!InputsWithCorrectTypes(node, context,
258 {{kTfLiteInt32},
259 {kTfLiteUInt8, kTfLiteInt8},
260 {kTfLiteUInt8, kTfLiteInt8},
261 {kTfLiteInt32}}))
262 return false;
263 } else {
264 return false;
265 }
266 const TfLiteTransposeConvParams* params =
267 reinterpret_cast<const TfLiteTransposeConvParams*>(
268 node->builtin_data);
269 return (params->stride_height <= 3 && params->stride_width <= 3 &&
270 (params->padding == kTfLitePaddingSame ||
271 params->padding == kTfLitePaddingValid));
272 }
273 case kTfLiteBuiltinConv2d: {
274 if (!InputsWithCorrectTypes(node, context,
275 {{kTfLiteUInt8, kTfLiteInt8},
276 {kTfLiteUInt8, kTfLiteInt8},
277 {kTfLiteInt32}}))
278 return false;
279 const TfLiteConvParams* conv_params =
280 reinterpret_cast<const TfLiteConvParams*>(node->builtin_data);
281 return (IsActivationReluOrNone(conv_params->activation) &&
282 conv_params->stride_height <= 3 &&
283 conv_params->stride_width <= 3 &&
284 conv_params->dilation_height_factor == 1 &&
285 conv_params->dilation_width_factor == 1);
286 }
287 case kTfLiteBuiltinDepthwiseConv2d: {
288 if (!InputsWithCorrectTypes(node, context,
289 {{kTfLiteUInt8, kTfLiteInt8},
290 {kTfLiteUInt8, kTfLiteInt8},
291 {kTfLiteInt32}}))
292 return false;
293
294 // Check dilation.
295 const TfLiteDepthwiseConvParams* conv_params =
296 reinterpret_cast<const TfLiteDepthwiseConvParams*>(
297 node->builtin_data);
298 const bool dilation = conv_params->dilation_height_factor != 1 ||
299 conv_params->dilation_width_factor != 1;
300 if (dilation) {
301 // We only support dilations when stride == 1.
302 if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
303 return false;
304 }
305
306 // We currently only support depth_multiplier > 1 when:
307 // 1. dilation_factor == 1 AND
308 // 2. input_depth == 1
309 // TODO(b/143759564): Add support for general case.
310 const auto& input = context->tensors[node->inputs->data[0]];
311 const bool supported_depth_multiplier =
312 conv_params->depth_multiplier == 1 ||
313 (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
314
315 return (IsActivationReluOrNone(conv_params->activation) &&
316 conv_params->stride_height <= 3 &&
317 conv_params->stride_width <= 3 && supported_depth_multiplier);
318 }
319 case kTfLiteBuiltinReshape: {
320 if (node->inputs->size > 2 ||
321 (!TensorTypeMatch(node->inputs->data[0], context, kTfLiteUInt8) &&
322 !TensorTypeMatch(node->inputs->data[0], context, kTfLiteInt8)))
323 return false;
324 return true;
325 }
326 case kTfLiteBuiltinSoftmax: {
327 return (
328 InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}));
329 }
330 case kTfLiteBuiltinHardSwish:
331 case kTfLiteBuiltinRelu:
332 case kTfLiteBuiltinRelu6:
333 case kTfLiteBuiltinTanh:
334 case kTfLiteBuiltinLogistic: {
335 return InputsWithCorrectTypes(node, context,
336 {{kTfLiteUInt8, kTfLiteInt8}});
337 }
338 case kTfLiteBuiltinResizeNearestNeighbor: {
339 return InputsWithCorrectTypes(
340 node, context,
341 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
342 IsConstantTensor(GetInput(context, node, 1));
343 }
344 case kTfLiteBuiltinL2Normalization: {
345 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
346 return false;
347 const TfLiteL2NormParams* norm_params =
348 reinterpret_cast<const TfLiteL2NormParams*>(node->builtin_data);
349 return (norm_params->activation == kTfLiteActNone);
350 }
351 case kTfLiteBuiltinArgMax:
352 case kTfLiteBuiltinArgMin:
353 return InputsWithCorrectTypes(
354 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
355 case kTfLiteBuiltinSplit: {
356 if (!InputsWithCorrectTypes(
357 node, context, {{kTfLiteInt32}, {kTfLiteUInt8, kTfLiteInt8}}))
358 return false;
359 const auto& input_tensor = context->tensors[node->inputs->data[1]];
360 const bool is_four_dim_or_less = input_tensor.dims->size < 5;
361 // We need splitting axis to be constant, so Hexagon knows output
362 // shapes.
363 return is_four_dim_or_less &&
364 IsConstantTensor(GetInput(context, node, 0));
365 }
366 case kTfLiteBuiltinResizeBilinear: {
367 if (!InputsWithCorrectTypes(
368 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
369 !IsConstantTensor(GetInput(context, node, 1))) {
370 return false;
371 }
372 const auto& size_tensor = context->tensors[node->inputs->data[1]];
373 // TODO(b/143105433): Latency increase significantly with large size
374 // value. Limiting to 65 for now.
375 return NumElements(&size_tensor) == 2 && size_tensor.data.i32[0] < 66 &&
376 size_tensor.data.i32[1] < 66;
377 }
378 case kTfLiteBuiltinNeg: {
379 return InputsWithCorrectTypes(node, context,
380 {{kTfLiteUInt8, kTfLiteInt8}});
381 }
382 case kTfLiteBuiltinTranspose: {
383 return InputsWithCorrectTypes(
384 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
385 }
386 case kTfLiteBuiltinSpaceToDepth:
387 case kTfLiteBuiltinDepthToSpace: {
388 return InputsWithCorrectTypes(node, context,
389 {{kTfLiteUInt8, kTfLiteInt8}});
390 }
391 case kTfLiteBuiltinQuantize: {
392 return InputsWithCorrectTypes(node, context,
393 {{kTfLiteUInt8, kTfLiteInt8}});
394 }
395 case kTfLiteBuiltinMinimum: {
396 return InputsWithCorrectTypes(
397 node, context,
398 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
399 }
400 case kTfLiteBuiltinMaximum: {
401 return InputsWithCorrectTypes(
402 node, context,
403 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
404 }
405 case kTfLiteBuiltinSlice: {
406 const auto& begins_tensor = context->tensors[node->inputs->data[1]];
407 const auto& sizes_tensor = context->tensors[node->inputs->data[2]];
408 if (!IsConstantTensor(&begins_tensor) || !IsConstantTensor(&sizes_tensor))
409 return false;
410 return InputsWithCorrectTypes(node, context,
411 {{kTfLiteUInt8, kTfLiteInt8},
412 {kTfLiteInt32, kTfLiteInt64},
413 {kTfLiteInt32, kTfLiteInt64}});
414 }
415 case kTfLiteBuiltinPack: {
416 // All tensors must be 8-bit.
417 for (int i = 0; i < node->inputs->size; ++i) {
418 if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
419 !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
420 return false;
421 }
422 return true;
423 }
424 case kTfLiteBuiltinStridedSlice: {
425 if (!InputsWithCorrectTypes(node, context,
426 {{kTfLiteUInt8, kTfLiteInt8},
427 {kTfLiteInt32},
428 {kTfLiteInt32},
429 {kTfLiteInt32}}))
430 return false;
431 const auto& begins_tensor = context->tensors[node->inputs->data[1]];
432 const auto& ends_tensor = context->tensors[node->inputs->data[2]];
433 const auto& step_tensor = context->tensors[node->inputs->data[3]];
434 if (!IsConstantTensor(&begins_tensor) ||
435 !IsConstantTensor(&ends_tensor) || !IsConstantTensor(&step_tensor))
436 return false;
437 const TfLiteStridedSliceParams* params =
438 reinterpret_cast<const TfLiteStridedSliceParams*>(node->builtin_data);
439 // Hexagon doesn't support ellipsis/new-axis masks.
440 return (params->ellipsis_mask == 0 && params->new_axis_mask == 0);
441 }
442 case kTfLiteBuiltinSquaredDifference: {
443 return InputsWithCorrectTypes(node, context,
444 {{kTfLiteInt8}, {kTfLiteInt8}});
445 }
446 case kTfLiteBuiltinRsqrt: {
447 return InputsWithCorrectTypes(node, context, {{kTfLiteInt8}});
448 }
449 default:
450 return false;
451 }
452 return false;
453 }
454
455 } // namespace tflite
456