/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/hexagon/utils.h"

#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace {

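// Returns true if 'activation' is a no-op (kTfLiteActNone) or one of the
// ReLU variants (Relu, Relu6, ReluN1To1).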
bool IsActivationReluOrNone(TfLiteFusedActivation activation) {
  return (activation == kTfLiteActRelu || activation == kTfLiteActRelu6 ||
          activation == kTfLiteActReluN1To1 || activation == kTfLiteActNone);
}

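// Returns true if the tensor with index 'tensor_id' in 'context' has type
// 'tensor_type'.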
bool TensorTypeMatch(int tensor_id, TfLiteContext* context,
                     TfLiteType tensor_type) {
  const auto& tensor = context->tensors[tensor_id];
  return tensor.type == tensor_type;
}

// For each input tensor i, checks if the type matches one of the possibilities
// in per_input_possible_types[i].
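// Optional inputs (tensor index -1) are skipped. As an illustration, a call
// like
//   InputsWithCorrectTypes(node, context,
//                          {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}})
// accepts a node whose first input is UInt8 or Int8 and whose second input
// is Int32.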
bool InputsWithCorrectTypes(
    const TfLiteNode* node, TfLiteContext* context,
    const std::vector<std::vector<TfLiteType>>& per_input_possible_types) {
  if (node->inputs->size != per_input_possible_types.size()) return false;
  for (int i = 0; i < per_input_possible_types.size(); ++i) {
    // Skip optional tensor.
    if (node->inputs->data[i] == -1) continue;
    bool type_found = false;
    for (auto possible_type : per_input_possible_types[i]) {
      if (TensorTypeMatch(node->inputs->data[i], context, possible_type)) {
        type_found = true;
        break;
      }
    }
    if (!type_found) return false;
  }
  return true;
}

}  // namespace

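// Maps 'dims' onto (batch, height, width, depth), padding missing leading
// dimensions with 1. For example, dims {2, 3} yields batch=1, height=1,
// width=2, depth=3. Returns kTfLiteError if 'dims' has more than 4
// dimensions.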
TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
                        unsigned int* width_size, unsigned int* depth_size,
                        TfLiteIntArray* dims) {
  if (dims->size > 4) return kTfLiteError;
  unsigned int* dim[] = {batch_size, height_size, width_size, depth_size};
  for (int i = 0; i < 4; ++i) *(dim[i]) = 1;
  for (int i = 4 - dims->size; i < 4; ++i) {
    *dim[i] = dims->data[i - (4 - dims->size)];
  }
  return kTfLiteOk;
}

// We maintain an op-version allowlist here to ensure we don't accept
// unintended ops.
bool CheckOpVersion(const TfLiteRegistration* registration) {
  switch (registration->builtin_code) {
    case kTfLiteBuiltinAdd:
    case kTfLiteBuiltinArgMax:
    case kTfLiteBuiltinArgMin:
    case kTfLiteBuiltinAveragePool2d:
    case kTfLiteBuiltinConcatenation:
    case kTfLiteBuiltinL2Normalization:
    case kTfLiteBuiltinLogistic:
    case kTfLiteBuiltinMaximum:
    case kTfLiteBuiltinMaxPool2d:
    case kTfLiteBuiltinMean:
    case kTfLiteBuiltinMinimum:
    case kTfLiteBuiltinMirrorPad:
    case kTfLiteBuiltinMul:
    case kTfLiteBuiltinPack:
    case kTfLiteBuiltinPad:
    case kTfLiteBuiltinQuantize:
    case kTfLiteBuiltinRelu6:
    case kTfLiteBuiltinSlice:
    case kTfLiteBuiltinSoftmax:
    case kTfLiteBuiltinSpaceToDepth:
    case kTfLiteBuiltinDepthToSpace:
    case kTfLiteBuiltinSplit:
    case kTfLiteBuiltinStridedSlice:
    case kTfLiteBuiltinSub:
    case kTfLiteBuiltinTanh:
    case kTfLiteBuiltinTranspose:
      return registration->version <= 2;
    case kTfLiteBuiltinSquaredDifference:
    case kTfLiteBuiltinRelu:
    case kTfLiteBuiltinRsqrt:
      return registration->version == 2;
    case kTfLiteBuiltinConv2d:
    case kTfLiteBuiltinDepthwiseConv2d:
    case kTfLiteBuiltinResizeBilinear:
    case kTfLiteBuiltinResizeNearestNeighbor:
    case kTfLiteBuiltinTransposeConv:
      return registration->version <= 3;
    case kTfLiteBuiltinFullyConnected:
      return registration->version <= 4;
    default:
      return registration->version == 1;
  }
}

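// Returns true if 'node' (with the given 'registration') can be delegated to
// Hexagon: all input and output tensors must have rank <= 4, the op version
// must pass CheckOpVersion(), and the op-specific constraints below must
// hold.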
bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
                              const TfLiteNode* node, TfLiteContext* context) {
  // Ensure all inputs & outputs have dim <= 4.
  int tensor_id;
  for (int i = 0; i < node->inputs->size; ++i) {
    tensor_id = node->inputs->data[i];
    // Skip optional tensors. Builders should handle optional tensors that are
    // not available.
    if (tensor_id == -1) continue;
    const auto& tensor = context->tensors[tensor_id];
    if (tensor.dims->size > 4) return false;
  }
  for (int i = 0; i < node->outputs->size; ++i) {
    tensor_id = node->outputs->data[i];
    const auto& tensor = context->tensors[tensor_id];
    if (tensor.dims->size > 4) return false;
  }

  if (!CheckOpVersion(registration)) return false;

  switch (registration->builtin_code) {
    case kTfLiteBuiltinAdd: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteAddParams* add_params =
          reinterpret_cast<const TfLiteAddParams*>(node->builtin_data);
      return IsActivationReluOrNone(add_params->activation);
    }
    case kTfLiteBuiltinMul: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteMulParams* mul_params =
          reinterpret_cast<const TfLiteMulParams*>(node->builtin_data);
      // TODO(b/129276536): Add support for activation on Mul node.
      return IsActivationReluOrNone(mul_params->activation);
    }
    case kTfLiteBuiltinSub: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteSubParams* sub_params =
          reinterpret_cast<const TfLiteSubParams*>(node->builtin_data);
      return IsActivationReluOrNone(sub_params->activation);
    }
    case kTfLiteBuiltinSum:
      // TODO(b/139277813): Enable these when they pass unit tests. These seem
      // to recompute the output min/max instead of taking them as inputs, which
      // causes an unexpected shift in dequantized values.
      return false;
    case kTfLiteBuiltinMean: {
      return InputsWithCorrectTypes(
                 node, context,
                 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
             IsConstantTensor(GetInput(context, node, 1));
    }
    case kTfLiteBuiltinMirrorPad: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
          !IsConstantTensor(GetInput(context, node, 1)))
        return false;
      const TfLiteMirrorPaddingParams* params =
          reinterpret_cast<const TfLiteMirrorPaddingParams*>(
              node->builtin_data);
      return params->mode == kTfLiteMirrorPaddingReflect ||
             params->mode == kTfLiteMirrorPaddingSymmetric;
    }
    case kTfLiteBuiltinPad: {
      // TODO(b/139277813): Currently we only support padding with the default
      // of 0. Add support for user-defined constant if required.
      return (
          node->inputs->size == 2 &&
          InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
          IsConstantTensor(GetInput(context, node, 1)));
    }
    case kTfLiteBuiltinFullyConnected: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32, kTfLiteNoType}})) {
        return false;
      }

      bool bias_const_or_no_bias = true;
      if (node->inputs->data[2] != -1) {
        const auto& bias_tensor = context->tensors[node->inputs->data[2]];
        bias_const_or_no_bias = bias_tensor.allocation_type == kTfLiteMmapRo;
      }

      const TfLiteFullyConnectedParams* matmul_params =
          reinterpret_cast<const TfLiteFullyConnectedParams*>(
              node->builtin_data);
      return (bias_const_or_no_bias &&
              IsActivationReluOrNone(matmul_params->activation) &&
              matmul_params->keep_num_dims == false &&
              matmul_params->weights_format ==
                  kTfLiteFullyConnectedWeightsFormatDefault);
    }
    case kTfLiteBuiltinConcatenation: {
      // All concatenated tensors must be 8-bit.
      for (int i = 0; i < node->inputs->size; ++i) {
        if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
            !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
          return false;
      }
      return true;
    }
    case kTfLiteBuiltinMaxPool2d: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      // TODO(b/129276536): Add support for activation here.
      const TfLitePoolParams* pool_params =
          reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      return pool_params->activation == kTfLiteActNone;
    }
    case kTfLiteBuiltinAveragePool2d: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLitePoolParams* pool_params =
          reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      return (node->inputs->size == 1 &&
              pool_params->activation == kTfLiteActNone);
    }
    case kTfLiteBuiltinTransposeConv: {
      if (NumInputs(node) == 3) {
        if (!InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt32},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteUInt8, kTfLiteInt8}}))
          return false;
      } else if (NumInputs(node) == 4) {
        if (!InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt32},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteInt32}}))
          return false;
      } else {
        return false;
      }
      const TfLiteTransposeConvParams* params =
          reinterpret_cast<const TfLiteTransposeConvParams*>(
              node->builtin_data);
      return (params->stride_height <= 3 && params->stride_width <= 3 &&
              (params->padding == kTfLitePaddingSame ||
               params->padding == kTfLitePaddingValid));
    }
    case kTfLiteBuiltinConv2d: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32}}))
        return false;
      const TfLiteConvParams* conv_params =
          reinterpret_cast<const TfLiteConvParams*>(node->builtin_data);
      return (IsActivationReluOrNone(conv_params->activation) &&
              conv_params->stride_height <= 3 &&
              conv_params->stride_width <= 3 &&
              conv_params->dilation_height_factor == 1 &&
              conv_params->dilation_width_factor == 1);
    }
    case kTfLiteBuiltinDepthwiseConv2d: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32}}))
        return false;

      // Check dilation.
      const TfLiteDepthwiseConvParams* conv_params =
          reinterpret_cast<const TfLiteDepthwiseConvParams*>(
              node->builtin_data);
      const bool dilation = conv_params->dilation_height_factor != 1 ||
                            conv_params->dilation_width_factor != 1;
      if (dilation) {
        // We only support dilations when stride == 1.
        if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
          return false;
      }

      // We currently only support depth_multiplier > 1 when:
      // 1. dilation_factor == 1 AND
      // 2. input_depth == 1
      // TODO(b/143759564): Add support for general case.
      const auto& input = context->tensors[node->inputs->data[0]];
      const bool supported_depth_multiplier =
          conv_params->depth_multiplier == 1 ||
          (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);

      return (IsActivationReluOrNone(conv_params->activation) &&
              conv_params->stride_height <= 3 &&
              conv_params->stride_width <= 3 && supported_depth_multiplier);
    }
    case kTfLiteBuiltinReshape: {
      if (node->inputs->size > 2 ||
          (!TensorTypeMatch(node->inputs->data[0], context, kTfLiteUInt8) &&
           !TensorTypeMatch(node->inputs->data[0], context, kTfLiteInt8)))
        return false;
      return true;
    }
    case kTfLiteBuiltinSoftmax: {
      return (
          InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}));
    }
    case kTfLiteBuiltinHardSwish:
    case kTfLiteBuiltinRelu:
    case kTfLiteBuiltinRelu6:
    case kTfLiteBuiltinTanh:
    case kTfLiteBuiltinLogistic: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinResizeNearestNeighbor: {
      return InputsWithCorrectTypes(
                 node, context,
                 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
             IsConstantTensor(GetInput(context, node, 1));
    }
    case kTfLiteBuiltinL2Normalization: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteL2NormParams* norm_params =
          reinterpret_cast<const TfLiteL2NormParams*>(node->builtin_data);
      return (norm_params->activation == kTfLiteActNone);
    }
    case kTfLiteBuiltinArgMax:
    case kTfLiteBuiltinArgMin:
      return InputsWithCorrectTypes(
          node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
    case kTfLiteBuiltinSplit: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteInt32}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const auto& input_tensor = context->tensors[node->inputs->data[1]];
      const bool is_four_dim_or_less = input_tensor.dims->size < 5;
      // We need the splitting axis to be constant, so Hexagon knows the
      // output shapes.
      return is_four_dim_or_less &&
             IsConstantTensor(GetInput(context, node, 0));
    }
    case kTfLiteBuiltinResizeBilinear: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
          !IsConstantTensor(GetInput(context, node, 1))) {
        return false;
      }
      const auto& size_tensor = context->tensors[node->inputs->data[1]];
      // TODO(b/143105433): Latency increases significantly with large size
      // values. Limiting to 65 for now.
      return NumElements(&size_tensor) == 2 && size_tensor.data.i32[0] < 66 &&
             size_tensor.data.i32[1] < 66;
    }
    case kTfLiteBuiltinNeg: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinTranspose: {
      return InputsWithCorrectTypes(
          node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
    }
    case kTfLiteBuiltinSpaceToDepth:
    case kTfLiteBuiltinDepthToSpace: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinQuantize: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinMinimum: {
      return InputsWithCorrectTypes(
          node, context,
          {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinMaximum: {
      return InputsWithCorrectTypes(
          node, context,
          {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinSlice: {
      const auto& begins_tensor = context->tensors[node->inputs->data[1]];
      const auto& sizes_tensor = context->tensors[node->inputs->data[2]];
      if (!IsConstantTensor(&begins_tensor) || !IsConstantTensor(&sizes_tensor))
        return false;
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteInt32, kTfLiteInt64},
                                     {kTfLiteInt32, kTfLiteInt64}});
    }
    case kTfLiteBuiltinPack: {
      // All tensors must be 8-bit.
      for (int i = 0; i < node->inputs->size; ++i) {
        if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
            !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
          return false;
      }
      return true;
    }
    case kTfLiteBuiltinStridedSlice: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32},
                                   {kTfLiteInt32},
                                   {kTfLiteInt32}}))
        return false;
      const auto& begins_tensor = context->tensors[node->inputs->data[1]];
      const auto& ends_tensor = context->tensors[node->inputs->data[2]];
      const auto& step_tensor = context->tensors[node->inputs->data[3]];
      if (!IsConstantTensor(&begins_tensor) ||
          !IsConstantTensor(&ends_tensor) || !IsConstantTensor(&step_tensor))
        return false;
      const TfLiteStridedSliceParams* params =
          reinterpret_cast<const TfLiteStridedSliceParams*>(node->builtin_data);
      // Hexagon doesn't support ellipsis/new-axis masks.
      return (params->ellipsis_mask == 0 && params->new_axis_mask == 0);
    }
    case kTfLiteBuiltinSquaredDifference: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt8}, {kTfLiteInt8}});
    }
    case kTfLiteBuiltinRsqrt: {
      return InputsWithCorrectTypes(node, context, {{kTfLiteInt8}});
    }
    default:
      return false;
  }
  return false;
}

}  // namespace tflite