1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
17 
18 #include <algorithm>
19 #include <array>
20 #include <cstdint>
21 #include <cstring>
22 #include <limits>
23 #include <memory>
24 #include <string>
25 #include <unordered_map>
26 #include <unordered_set>
27 #include <utility>
28 #include <vector>
29 
30 #include <fp16.h>
31 #include <xnnpack.h>
32 #include "tensorflow/lite/builtin_ops.h"
33 #include "tensorflow/lite/c/builtin_op_data.h"
34 #include "tensorflow/lite/c/common.h"
35 #include "tensorflow/lite/minimal_logging.h"
36 #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
37 
38 namespace tflite {
39 namespace xnnpack {
40 namespace {
41 
42 // Forward declaration.
43 TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);
44 
45 class Delegate {
46   friend class Subgraph;
47 
48  public:
Delegate(const TfLiteXNNPackDelegateOptions * options)49   explicit Delegate(const TfLiteXNNPackDelegateOptions* options) {
50 #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
51     if (options != nullptr && options->num_threads > 1) {
52       threadpool_.reset(
53           pthreadpool_create(static_cast<size_t>(options->num_threads)));
54     }
55 #endif
56     TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
57                          "Created TensorFlow Lite XNNPACK delegate for CPU.");
58   }
59 
60   TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
tflite_delegate()61   TfLiteDelegate* tflite_delegate() { return &delegate_; }
62 
threadpool() const63   pthreadpool_t threadpool() const {
64 #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
65     return nullptr;
66 #else
67     return threadpool_.get();
68 #endif
69   }
70 
71  private:
72   TfLiteDelegate delegate_ = {
73       reinterpret_cast<void*>(this),  // .data_
74       DelegatePrepare,                // .Prepare
75       nullptr,                        // .CopyFromBufferHandle
76       nullptr,                        // .CopyToBufferHandle
77       nullptr,                        // .FreeBufferHandle
78       kTfLiteDelegateFlagsNone,       // .flags
79   };
80 
81   // Unpacked data for quasi-static tensors, i.e. tensors produced by
82   // dequantizing or unpacking static buffers.
83   std::vector<char> static_unpacked_data_;
84   // Mapping from a tensor index for a quasi-static tensor to the offset to
85   // its unpacked data within static_unpacked_data_.
86   std::unordered_map<int, size_t> static_unpacked_data_map_;
87   // Set of indices of nodes which unpack static data, e.g. Dequantize
88   // operators which convert FP16 static weights to FP32. These nodes are simply
89   // ignored in the delegate implementation, because their outputs are
90   // pre-unpacked in DelegatePrepare.
91   std::unordered_set<int> static_unpack_nodes_;
92   // Set of indices of tensors with unpacked static sparse weights.
93   std::unordered_set<int> static_sparse_weights_;
94 #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
95   // Thread pool with smart-pointer for lifetime management.
96   std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_{
97       nullptr, &pthreadpool_destroy};
98 #endif
99 };
100 
101 class Subgraph {
102  public:
Create(TfLiteContext * context,const TfLiteDelegateParams * params,const Delegate * delegate)103   static Subgraph* Create(TfLiteContext* context,
104                           const TfLiteDelegateParams* params,
105                           const Delegate* delegate) {
106     // Convert subgraph inputs and outputs to hash sets for faster lookup.
107     const std::unordered_set<int> inputs(
108         &params->input_tensors->data[0],
109         &params->input_tensors->data[params->input_tensors->size]);
110     std::unordered_set<int> outputs;
111     for (int o = 0; o < params->output_tensors->size; o++) {
112       const int output_tensor_idx = params->output_tensors->data[o];
113       // Exclude quasi-static tensors which may have become subgraph outputs
114       // after partitioning.
115       if (delegate->static_unpacked_data_map_.count(output_tensor_idx) == 0) {
116         outputs.insert(output_tensor_idx);
117       }
118     }
119     std::unordered_set<int> externals(outputs);
120 
121     TfLiteIntArray* execution_plan;
122     if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
123       return nullptr;
124     }
125 
126     xnn_subgraph_t subgraph_ptr = nullptr;
127     xnn_status status = xnn_create_subgraph(
128         /*external_value_ids=*/context->tensors_size, /*flags=*/0,
129         &subgraph_ptr);
130     if (status != xnn_status_success) {
131       TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK subgraph");
132       return nullptr;
133     }
134 
135     // Smart pointer to automatically release subgraph on exit.
136     std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
137         subgraph_ptr, &xnn_delete_subgraph);
138 
139     bool has_sparse_weights = false;
140     // Detect which tensors are used as inputs or outputs of any subgraph nodes.
141     // -1 denotes tensor not used in the subgraph. These indexes will be
142     // filtered out and removed later.
143     std::vector<int> tensors(context->tensors_size, -1);
144     for (int i = 0; i < params->nodes_to_replace->size; i++) {
145       const int node_index = params->nodes_to_replace->data[i];
146 
147       TfLiteNode* node = nullptr;
148       TfLiteRegistration* registration = nullptr;
149       if (context->GetNodeAndRegistration(context, node_index, &node,
150                                           &registration) != kTfLiteOk) {
151         return nullptr;
152       }
153 
154       // Detect if any of the node's inputs are sparse weights.
155       if (!has_sparse_weights) {
156         for (int i = 0; i < node->inputs->size; i++) {
157           if (delegate->static_sparse_weights_.count(node->inputs->data[i]) !=
158               0) {
159             has_sparse_weights = true;
160           }
161         }
162       }
163 
164       if (delegate->static_unpack_nodes_.count(node_index) != 0) {
165         // The node unpacks static input and can be skipped because its input
166         // was pre-unpacked in DelegatePrepare.
167         continue;
168       }
169 
170       switch (registration->builtin_code) {
171         case kTfLiteBuiltinMean:
172         case kTfLiteBuiltinPad:
173         case kTfLiteBuiltinReshape:
174         case kTfLiteBuiltinResizeBilinear:
175           // Ignore the second input (axes, static padding, or new shape),
176           // because it is represented as parameters of the XNNPACK operator
177           // rather than extra input.
178           {
179             const int t = node->inputs->data[0];
180             tensors[t] = t;
181           }
182           break;
183         default:
184           // All other operators: process all inputs
185           for (int k = 0; k < node->inputs->size; k++) {
186             const int t = node->inputs->data[k];
187             tensors[t] = t;
188           }
189       }
190       for (int k = 0; k < node->outputs->size; k++) {
191         const int t = node->outputs->data[k];
192         tensors[t] = t;
193       }
194     }
195     // Filter out and remove -1 (unused) indexes.
196     tensors.erase(std::remove_if(tensors.begin(), tensors.end(),
197                                  [](int i) { return i < 0; }),
198                   tensors.end());
199     std::sort(tensors.begin(), tensors.end());
200 
201     // XNNPACK Value IDs for TFLite tensors
202     std::vector<uint32_t> xnnpack_tensors(tensors.back() + 1);
203     for (int t : tensors) {
204       if (context->tensors[t].type != kTfLiteFloat32) {
205         TF_LITE_KERNEL_LOG(
206             context,
207             "unsupported datatype (%s) of tensor %d in XNNPACK delegate",
208             TfLiteTypeGetName(context->tensors[t].type), t);
209         return nullptr;
210       }
211 
212       uint32_t flags = 0;
213       const void* data = nullptr;
214       if (context->tensors[t].allocation_type == kTfLiteMmapRo) {
215         data = context->tensors[t].data.raw_const;
216       } else {
217         // Check for quasi-static data.
218         const auto it = delegate->static_unpacked_data_map_.find(t);
219         if (it != delegate->static_unpacked_data_map_.end()) {
220           data = delegate->static_unpacked_data_.data() + it->second;
221         }
222       }
223       if (inputs.count(t) != 0) {
224         flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
225         if (data == nullptr) {
226           externals.insert(t);
227         }
228       }
229       if (outputs.count(t) != 0) {
230         flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
231       }
232 
233       std::vector<size_t> dims(
234           &context->tensors[t].dims->data[0],
235           &context->tensors[t].dims->data[context->tensors[t].dims->size]);
236 
237       const xnn_status status = xnn_define_tensor_value(
238           subgraph.get(), xnn_datatype_fp32, dims.size(), dims.data(), data,
239           static_cast<uint32_t>(t), flags, &xnnpack_tensors[t]);
240       if (status != xnn_status_success) {
241         TF_LITE_KERNEL_LOG(context,
242                            "failed to create XNNPACK Value for tensor %d", t);
243         return nullptr;
244       }
245     }
246 
247     // Create a set of quasi-static tensors for VisitNode function
248     std::unordered_set<int> quasi_static_tensors;
249     for (const std::pair<const int, size_t>& entry :
250          delegate->static_unpacked_data_map_) {
251       quasi_static_tensors.insert(entry.first);
252     }
253 
254     // Create XNNPACK nodes for TFLite delegate nodes
255     for (int i = 0; i < params->nodes_to_replace->size; i++) {
256       const int node_index = params->nodes_to_replace->data[i];
257       if (delegate->static_unpack_nodes_.count(node_index)) {
258         // The node unpacks static input and can be skipped because its input
259         // was pre-unpacked in DelegatePrepare.
260         continue;
261       }
262 
263       TfLiteNode* node = nullptr;
264       TfLiteRegistration* registration = nullptr;
265       if (context->GetNodeAndRegistration(context, node_index, &node,
266                                           &registration) != kTfLiteOk) {
267         return nullptr;
268       }
269 
270       if (VisitNode(subgraph.get(), context, registration, node, node_index,
271                     quasi_static_tensors, xnnpack_tensors) != kTfLiteOk) {
272         return nullptr;
273       }
274     }
275 
276     xnn_runtime_t runtime_ptr = nullptr;
277     const uint32_t flags = has_sparse_weights ? XNN_FLAG_SPARSE_INFERENCE : 0;
278     status = xnn_create_runtime_v2(subgraph.get(), delegate->threadpool(),
279                                    flags, &runtime_ptr);
280     if (status != xnn_status_success) {
281       TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime");
282       return nullptr;
283     }
284 
285     return new Subgraph(runtime_ptr, std::move(externals));
286   }
287 
Prepare(TfLiteContext * context)288   TfLiteStatus Prepare(TfLiteContext* context) { return kTfLiteOk; }
289 
Invoke(TfLiteContext * context)290   TfLiteStatus Invoke(TfLiteContext* context) {
291     if (first_run_) {
292       std::vector<xnn_external_value> external_values;
293       for (int t : externals_) {
294         xnn_external_value value = {0};
295         value.id = static_cast<uint32_t>(t);
296         value.data = context->tensors[t].data.raw;
297         external_values.push_back(value);
298       }
299 
300       const xnn_status status = xnn_setup_runtime(
301           runtime_.get(), external_values.size(), external_values.data());
302       if (status != xnn_status_success) {
303         TF_LITE_KERNEL_LOG(context, "failed to setup XNNPACK runtime");
304         return kTfLiteError;
305       }
306 
307       first_run_ = false;
308     }
309 
310     const xnn_status status = xnn_invoke_runtime(runtime_.get());
311     if (status != xnn_status_success) {
312       TF_LITE_KERNEL_LOG(context, "failed to invoke XNNPACK runtime");
313       return kTfLiteError;
314     }
315 
316     return kTfLiteOk;
317   }
318 
CalculatePadding(TfLiteContext * context,TfLitePadding padding,uint32_t * flags,int node_index)319   static TfLiteStatus CalculatePadding(TfLiteContext* context,
320                                        TfLitePadding padding, uint32_t* flags,
321                                        int node_index) {
322     switch (padding) {
323       case kTfLitePaddingSame: {
324         *flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
325         return kTfLiteOk;
326       }
327       case kTfLitePaddingValid:
328         *flags = 0;
329         return kTfLiteOk;
330       default:
331         TF_LITE_MAYBE_KERNEL_LOG(context,
332                                  "invalid padding mode (%d) in node #%d",
333                                  static_cast<int>(padding), node_index);
334         return kTfLiteError;
335     }
336   }
337 
ConvertActivationToOutputRange(TfLiteContext * context,int node_index,TfLiteFusedActivation activation,float * output_min,float * output_max)338   static TfLiteStatus ConvertActivationToOutputRange(
339       TfLiteContext* context, int node_index, TfLiteFusedActivation activation,
340       float* output_min, float* output_max) {
341     switch (activation) {
342       case kTfLiteActNone:
343         *output_min = -std::numeric_limits<float>::infinity();
344         *output_max = +std::numeric_limits<float>::infinity();
345         return kTfLiteOk;
346       case kTfLiteActRelu:
347         *output_min = 0.0f;
348         *output_max = +std::numeric_limits<float>::infinity();
349         return kTfLiteOk;
350       case kTfLiteActReluN1To1:
351         *output_min = -1.0f;
352         *output_max = +1.0f;
353         return kTfLiteOk;
354       case kTfLiteActRelu6:
355         *output_min = 0.0f;
356         *output_max = 6.0f;
357         return kTfLiteOk;
358       case kTfLiteActTanh:
359         TF_LITE_MAYBE_KERNEL_LOG(
360             context, "unsupported fused activation (Tanh) in node #%d",
361             node_index);
362         return kTfLiteError;
363       case kTfLiteActSignBit:
364         TF_LITE_MAYBE_KERNEL_LOG(
365             context, "unsupported fused activation (Sign) in node #%d",
366             node_index);
367         return kTfLiteError;
368       case kTfLiteActSigmoid:
369         TF_LITE_MAYBE_KERNEL_LOG(
370             context, "unsupported fused activation (Sigmoid) in node #%d",
371             node_index);
372         return kTfLiteError;
373       default:
374         TF_LITE_MAYBE_KERNEL_LOG(context,
375                                  "invalid fused activation (%d) in node #%d",
376                                  static_cast<int>(activation), node_index);
377         return kTfLiteError;
378     }
379   }
380 
CheckConvolutionParams(TfLiteContext * context,const TfLiteConvParams * params,int node_index)381   static TfLiteStatus CheckConvolutionParams(TfLiteContext* context,
382                                              const TfLiteConvParams* params,
383                                              int node_index) {
384     if (params->stride_width <= 0) {
385       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
386                                params->stride_width, node_index);
387       return kTfLiteError;
388     }
389     if (params->stride_height <= 0) {
390       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
391                                params->stride_height, node_index);
392       return kTfLiteError;
393     }
394 
395     if (params->dilation_width_factor <= 0) {
396       TF_LITE_MAYBE_KERNEL_LOG(context,
397                                "invalid dilation width factor %d in node #%d",
398                                params->dilation_width_factor, node_index);
399       return kTfLiteError;
400     }
401     if (params->dilation_height_factor <= 0) {
402       TF_LITE_MAYBE_KERNEL_LOG(context,
403                                "invalid dilation height factor %d in node #%d",
404                                params->dilation_height_factor, node_index);
405       return kTfLiteError;
406     }
407 
408     return kTfLiteOk;
409   }
410 
CheckDepthwiseConvolutionParams(TfLiteContext * context,const TfLiteDepthwiseConvParams * params,int output_channels,int node_index)411   static TfLiteStatus CheckDepthwiseConvolutionParams(
412       TfLiteContext* context, const TfLiteDepthwiseConvParams* params,
413       int output_channels, int node_index) {
414     if (params->stride_width <= 0) {
415       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
416                                params->stride_width, node_index);
417       return kTfLiteError;
418     }
419     if (params->stride_height <= 0) {
420       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
421                                params->stride_height, node_index);
422       return kTfLiteError;
423     }
424 
425     if (params->depth_multiplier <= 0) {
426       TF_LITE_MAYBE_KERNEL_LOG(context,
427                                "invalid depth multiplier %d in node #%d",
428                                params->depth_multiplier, node_index);
429       return kTfLiteError;
430     }
431     if (output_channels % params->depth_multiplier != 0) {
432       TF_LITE_MAYBE_KERNEL_LOG(context,
433                                "depth multiplier %d is incompatible with "
434                                "number of output channels %d in node #%d",
435                                params->depth_multiplier, output_channels,
436                                node_index);
437       return kTfLiteError;
438     }
439 
440     if (params->dilation_width_factor <= 0) {
441       TF_LITE_MAYBE_KERNEL_LOG(context,
442                                "invalid dilation width factor %d in node #%d",
443                                params->dilation_width_factor, node_index);
444       return kTfLiteError;
445     }
446     if (params->dilation_height_factor <= 0) {
447       TF_LITE_MAYBE_KERNEL_LOG(context,
448                                "invalid dilation height factor %d in node #%d",
449                                params->dilation_height_factor, node_index);
450       return kTfLiteError;
451     }
452 
453     return kTfLiteOk;
454   }
455 
CheckMediaPipeTransposedConvolutionParams(TfLiteContext * context,const TfLiteTransposeConvParams * params,int node_index)456   static TfLiteStatus CheckMediaPipeTransposedConvolutionParams(
457       TfLiteContext* context, const TfLiteTransposeConvParams* params,
458       int node_index) {
459     if (params->stride_width <= 0) {
460       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
461                                params->stride_width, node_index);
462       return kTfLiteError;
463     }
464     if (params->stride_height <= 0) {
465       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
466                                params->stride_height, node_index);
467       return kTfLiteError;
468     }
469 
470     return kTfLiteOk;
471   }
472 
CheckMediaPipePoolParams(TfLiteContext * context,const TfLitePoolParams * params,int node_index)473   static TfLiteStatus CheckMediaPipePoolParams(TfLiteContext* context,
474                                                const TfLitePoolParams* params,
475                                                int node_index) {
476     if (params->stride_width <= 0) {
477       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
478                                params->stride_width, node_index);
479       return kTfLiteError;
480     }
481     if (params->stride_height <= 0) {
482       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
483                                params->stride_height, node_index);
484       return kTfLiteError;
485     }
486     if (params->filter_width <= 0) {
487       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter width %d in node #%d",
488                                params->filter_width, node_index);
489       return kTfLiteError;
490     }
491     if (params->filter_height <= 0) {
492       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter height %d in node #%d",
493                                params->filter_height, node_index);
494       return kTfLiteError;
495     }
496     if (params->filter_width != params->stride_width) {
497       TF_LITE_MAYBE_KERNEL_LOG(
498           context, "filter width %d does not match stride width %d in node #%d",
499           params->filter_width, params->stride_width, node_index);
500       return kTfLiteError;
501     }
502     if (params->filter_height != params->stride_height) {
503       TF_LITE_MAYBE_KERNEL_LOG(
504           context,
505           "filter height %d does not match stride height %d in node #%d",
506           params->filter_height, params->stride_height, node_index);
507       return kTfLiteError;
508     }
509     switch (params->activation) {
510       case kTfLiteActNone:
511         break;
512       case kTfLiteActRelu:
513         TF_LITE_MAYBE_KERNEL_LOG(
514             context, "unsupported fused activation (Relu) in node #%d",
515             node_index);
516         return kTfLiteOk;
517       case kTfLiteActReluN1To1:
518         TF_LITE_MAYBE_KERNEL_LOG(
519             context, "unsupported fused activation (ReluMinus1To1) in node #%d",
520             node_index);
521         return kTfLiteOk;
522       case kTfLiteActRelu6:
523         TF_LITE_MAYBE_KERNEL_LOG(
524             context, "unsupported fused activation (Relu6) in node #%d",
525             node_index);
526         return kTfLiteOk;
527       case kTfLiteActTanh:
528         TF_LITE_MAYBE_KERNEL_LOG(
529             context, "unsupported fused activation (Tanh) in node #%d",
530             node_index);
531         return kTfLiteError;
532       case kTfLiteActSignBit:
533         TF_LITE_MAYBE_KERNEL_LOG(
534             context, "unsupported fused activation (Sign) in node #%d",
535             node_index);
536         return kTfLiteError;
537       case kTfLiteActSigmoid:
538         TF_LITE_MAYBE_KERNEL_LOG(
539             context, "unsupported fused activation (Sigmoid) in node #%d",
540             node_index);
541         return kTfLiteError;
542       default:
543         TF_LITE_MAYBE_KERNEL_LOG(
544             context, "invalid fused activation (%d) in node #%d",
545             static_cast<int>(params->activation), node_index);
546         return kTfLiteError;
547     }
548 
549     return kTfLiteOk;
550   }
551 
CheckFullyConnectedParams(TfLiteContext * context,const TfLiteFullyConnectedParams * params,int node_index)552   static TfLiteStatus CheckFullyConnectedParams(
553       TfLiteContext* context, const TfLiteFullyConnectedParams* params,
554       int node_index) {
555     if (params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
556       TF_LITE_MAYBE_KERNEL_LOG(
557           context, "unsupported non-default weights format in node #%d",
558           node_index);
559       return kTfLiteError;
560     }
561 
562     return kTfLiteOk;
563   }
564 
CheckPoolingParams(TfLiteContext * context,const TfLitePoolParams * params,int node_index)565   static TfLiteStatus CheckPoolingParams(TfLiteContext* context,
566                                          const TfLitePoolParams* params,
567                                          int node_index) {
568     if (params->stride_width <= 0) {
569       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
570                                params->stride_width, node_index);
571       return kTfLiteError;
572     }
573     if (params->stride_height <= 0) {
574       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
575                                params->stride_height, node_index);
576       return kTfLiteError;
577     }
578 
579     if (params->filter_width <= 0) {
580       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter width %d in node #%d",
581                                params->filter_width, node_index);
582       return kTfLiteError;
583     }
584     if (params->filter_height <= 0) {
585       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter height %d in node #%d",
586                                params->filter_height, node_index);
587       return kTfLiteError;
588     }
589 
590     if (params->filter_width == 1 && params->filter_height == 1 &&
591         std::max(params->stride_width, params->stride_height) > 1) {
592       TF_LITE_MAYBE_KERNEL_LOG(context,
593                                "unsupported pooling with 1x1 filter "
594                                "and %dx%d stride in node #%d",
595                                params->stride_width, params->stride_height,
596                                node_index);
597       return kTfLiteError;
598     }
599 
600     return kTfLiteOk;
601   }
602 
CheckNumInputsAndOutputs(TfLiteContext * context,TfLiteNode * node,int expected_num_inputs,int expected_num_outputs,int node_index)603   static TfLiteStatus CheckNumInputsAndOutputs(TfLiteContext* context,
604                                                TfLiteNode* node,
605                                                int expected_num_inputs,
606                                                int expected_num_outputs,
607                                                int node_index) {
608     if (node->inputs->size != expected_num_inputs) {
609       TF_LITE_MAYBE_KERNEL_LOG(
610           context, "unexpected number of inputs (%d != %d) in node #%d",
611           node->inputs->size, expected_num_inputs, node_index);
612       return kTfLiteError;
613     }
614     if (node->outputs->size != expected_num_outputs) {
615       TF_LITE_MAYBE_KERNEL_LOG(
616           context, "unexpected number of output (%d != %d) in node #%d",
617           node->outputs->size, expected_num_outputs, node_index);
618       return kTfLiteError;
619     }
620     return kTfLiteOk;
621   }
622 
CheckTensorType(TfLiteContext * context,const TfLiteTensor & tensor,TfLiteType expected_type,int tensor_index,int node_index)623   static TfLiteStatus CheckTensorType(TfLiteContext* context,
624                                       const TfLiteTensor& tensor,
625                                       TfLiteType expected_type,
626                                       int tensor_index, int node_index) {
627     if (tensor.type != expected_type) {
628       TF_LITE_MAYBE_KERNEL_LOG(
629           context, "unsupported type %s in tensor #%d in node #%d",
630           TfLiteTypeGetName(tensor.type), tensor_index, node_index);
631       return kTfLiteError;
632     }
633     return kTfLiteOk;
634   }
635 
CheckTensorFloatType(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)636   static TfLiteStatus CheckTensorFloatType(TfLiteContext* context,
637                                            const TfLiteTensor& tensor,
638                                            int tensor_index, int node_index) {
639     return CheckTensorType(context, tensor, kTfLiteFloat32, tensor_index,
640                            node_index);
641   }
642 
CheckTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int min_num_dims,int max_num_dims,int tensor_index)643   static TfLiteStatus CheckTensorShape(TfLiteContext* context,
644                                        const TfLiteTensor& tensor,
645                                        int min_num_dims, int max_num_dims,
646                                        int tensor_index) {
647     if (min_num_dims == max_num_dims) {
648       if (tensor.dims->size != min_num_dims) {
649         TF_LITE_MAYBE_KERNEL_LOG(
650             context,
651             "unsupported number of shape dimensions (%d) in tensor #%d: "
652             "%d dimensions expected",
653             tensor.dims->size, tensor_index, min_num_dims);
654         return kTfLiteError;
655       }
656     } else {
657       if (tensor.dims->size < min_num_dims) {
658         TF_LITE_MAYBE_KERNEL_LOG(
659             context,
660             "unsupported number of shape dimensions (%d) in tensor #%d: "
661             "at least %d dimensions expected",
662             tensor.dims->size, tensor_index, min_num_dims);
663         return kTfLiteError;
664       }
665       if (tensor.dims->size > max_num_dims) {
666         TF_LITE_MAYBE_KERNEL_LOG(
667             context,
668             "unsupported number of shape dimensions (%d) in tensor #%d: "
669             "at most %d dimensions expected",
670             tensor.dims->size, tensor_index, max_num_dims);
671         return kTfLiteError;
672       }
673     }
674     for (int i = 0; i < tensor.dims->size; i++) {
675       if (tensor.dims->data[i] <= 0) {
676         TF_LITE_MAYBE_KERNEL_LOG(context,
677                                  "invalid num of elements (%d) in "
678                                  "dimension #%d in tensor #%d",
679                                  tensor.dims->data[i], i, tensor_index);
680         return kTfLiteError;
681       }
682     }
683     return kTfLiteOk;
684   }
685 
CheckTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int expected_num_dims,int tensor_index)686   static TfLiteStatus CheckTensorShape(TfLiteContext* context,
687                                        const TfLiteTensor& tensor,
688                                        int expected_num_dims,
689                                        int tensor_index) {
690     return CheckTensorShape(context, tensor, expected_num_dims,
691                             expected_num_dims, tensor_index);
692   }
693 
CheckSlopeTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)694   static TfLiteStatus CheckSlopeTensorShape(TfLiteContext* context,
695                                             const TfLiteTensor& tensor,
696                                             int tensor_index, int node_index) {
697     if (tensor.dims->size < 1) {
698       TF_LITE_MAYBE_KERNEL_LOG(context,
699                                "unexpected number of shape dimensions (%d) in "
700                                "tensor #%d in node #%d: "
701                                "expected at least a 1D tensor",
702                                tensor.dims->size, tensor_index, node_index);
703       return kTfLiteError;
704     }
705     // Validate that all non-channel dimensions (if any) are exactly 1.
706     for (int i = 0; i < tensor.dims->size - 1; i++) {
707       if (tensor.dims->data[i] != 1) {
708         TF_LITE_MAYBE_KERNEL_LOG(
709             context,
710             "unexpected value %d of shape dimension #%d in "
711             "tensor #%d in node #%d: "
712             "expected 1 for non-channel dimensions",
713             tensor.dims[i], i, tensor_index, node_index);
714         return kTfLiteError;
715       }
716     }
717     return kTfLiteOk;
718   }
719 
CheckPaddingsTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int expected_rows,int tensor_index,int node_index)720   static TfLiteStatus CheckPaddingsTensorShape(TfLiteContext* context,
721                                                const TfLiteTensor& tensor,
722                                                int expected_rows,
723                                                int tensor_index,
724                                                int node_index) {
725     if (tensor.dims->size != 2) {
726       TF_LITE_MAYBE_KERNEL_LOG(context,
727                                "unexpected number of shape dimensions (%d) in "
728                                "padding tensor #%d in node #%d: "
729                                "expected a 2D tensor",
730                                tensor.dims->size, tensor_index, node_index);
731       return kTfLiteError;
732     }
733     if (tensor.dims->data[0] != expected_rows) {
734       TF_LITE_MAYBE_KERNEL_LOG(context,
735                                "unexpected number of rows (%d) in "
736                                "padding tensor #%d in node #%d: "
737                                "%d rows expected",
738                                tensor.dims->size, tensor_index, node_index,
739                                expected_rows);
740       return kTfLiteError;
741     }
742     if (tensor.dims->data[1] != 2) {
743       TF_LITE_MAYBE_KERNEL_LOG(context,
744                                "unexpected number of columns (%d) in "
745                                "padding tensor #%d in node #%d: "
746                                "2 columns expected",
747                                tensor.dims->size, tensor_index, node_index);
748       return kTfLiteError;
749     }
750     return kTfLiteOk;
751   }
752 
CheckAxesTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)753   static TfLiteStatus CheckAxesTensorShape(TfLiteContext* context,
754                                            const TfLiteTensor& tensor,
755                                            int tensor_index, int node_index) {
756     if (tensor.dims->size != 1) {
757       TF_LITE_MAYBE_KERNEL_LOG(context,
758                                "unexpected number of shape dimensions (%d) in "
759                                "axes tensor #%d in node #%d: "
760                                "expected a 1D tensor",
761                                tensor.dims->size, tensor_index, node_index);
762       return kTfLiteError;
763     }
764     return kTfLiteOk;
765   }
766 
CheckShapeTensorShape(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)767   static TfLiteStatus CheckShapeTensorShape(TfLiteContext* context,
768                                             const TfLiteTensor& tensor,
769                                             int tensor_index, int node_index) {
770     if (tensor.dims->size != 1) {
771       TF_LITE_MAYBE_KERNEL_LOG(context,
772                                "unexpected number of shape dimensions (%d) in "
773                                "shape tensor #%d in node #%d: "
774                                "expected a 1D tensor",
775                                tensor.dims->size, tensor_index, node_index);
776       return kTfLiteError;
777     }
778     return kTfLiteOk;
779   }
780 
CheckTensorNonDynamicAllocation(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)781   static TfLiteStatus CheckTensorNonDynamicAllocation(
782       TfLiteContext* context, const TfLiteTensor& tensor, int tensor_index,
783       int node_index) {
784     // TODO(b/149120844): remove checks once dynamic tensors are supported
785     if (tensor.allocation_type == kTfLiteDynamic) {
786       TF_LITE_MAYBE_KERNEL_LOG(
787           context,
788           "invalid allocation type in tensor #%d in node #%d: "
789           "expected non-dynamic tensor",
790           tensor_index, node_index);
791       return kTfLiteError;
792     }
793     return kTfLiteOk;
794   }
795 
CheckTensorStaticAllocation(TfLiteContext * context,const TfLiteTensor & tensor,int tensor_index,int node_index)796   static TfLiteStatus CheckTensorStaticAllocation(TfLiteContext* context,
797                                                   const TfLiteTensor& tensor,
798                                                   int tensor_index,
799                                                   int node_index) {
800     if (tensor.allocation_type != kTfLiteMmapRo ||
801         tensor.data.raw_const == nullptr) {
802       TF_LITE_MAYBE_KERNEL_LOG(
803           context,
804           "invalid allocation type in tensor #%d in node #%d: "
805           "expected static read-only tensor",
806           tensor_index, node_index);
807       return kTfLiteError;
808     }
809     return kTfLiteOk;
810   }
811 
VisitNode(xnn_subgraph_t subgraph,TfLiteContext * context,TfLiteRegistration * registration,TfLiteNode * node,int node_index,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)812   static TfLiteStatus VisitNode(
813       xnn_subgraph_t subgraph, TfLiteContext* context,
814       TfLiteRegistration* registration, TfLiteNode* node, int node_index,
815       const std::unordered_set<int>& quasi_static_tensors,
816       const std::vector<uint32_t>& xnnpack_tensors) {
817     // TFLite context used for logging purposes. When we create a new node
818     // (subgraph is non-null), logging context is the same as context, and error
819     // messages are passed to TFLite. When we detect supported operations
820     // (subgraph is null), logging context is null, and error messages are
821     // supressed.
822     TfLiteContext* logging_context = subgraph == nullptr ? nullptr : context;
823     switch (registration->builtin_code) {
824       case kTfLiteBuiltinAbs:
825         return VisitAbsNode(subgraph, logging_context, node_index, node,
826                             context->tensors, xnnpack_tensors);
827       case kTfLiteBuiltinAdd: {
828         const TfLiteAddParams* add_params =
829             static_cast<const TfLiteAddParams*>(node->builtin_data);
830 
831         return VisitAddNode(subgraph, logging_context, node_index, node,
832                             context->tensors, add_params, xnnpack_tensors);
833       }
834       case kTfLiteBuiltinAveragePool2d: {
835         const TfLitePoolParams* pool_params =
836             static_cast<const TfLitePoolParams*>(node->builtin_data);
837 
838         return VisitAveragePool2DNode(subgraph, logging_context, node_index,
839                                       node, context->tensors, pool_params,
840                                       xnnpack_tensors);
841       }
842       case kTfLiteBuiltinCeil:
843         return VisitCeilNode(subgraph, logging_context, node_index, node,
844                              context->tensors, xnnpack_tensors);
845       case kTfLiteBuiltinConv2d: {
846         const TfLiteConvParams* conv_params =
847             static_cast<const TfLiteConvParams*>(node->builtin_data);
848 
849         return VisitConv2DNode(subgraph, logging_context, node_index, node,
850                                context->tensors, conv_params,
851                                quasi_static_tensors, xnnpack_tensors);
852       }
853       case kTfLiteBuiltinDepthwiseConv2d: {
854         const TfLiteDepthwiseConvParams* dwconv_params =
855             static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data);
856 
857         return VisitDepthwiseConv2DNode(subgraph, logging_context, node_index,
858                                         node, context->tensors, dwconv_params,
859                                         quasi_static_tensors, xnnpack_tensors);
860       }
861       case kTfLiteBuiltinDepthToSpace: {
862         const TfLiteDepthToSpaceParams* depth_to_space_params =
863             static_cast<const TfLiteDepthToSpaceParams*>(node->builtin_data);
864 
865         return VisitDepthToSpaceNode(subgraph, logging_context, node_index,
866                                      node, context->tensors,
867                                      depth_to_space_params, xnnpack_tensors);
868       }
869       case kTfLiteBuiltinDiv: {
870         const TfLiteDivParams* div_params =
871             static_cast<const TfLiteDivParams*>(node->builtin_data);
872 
873         return VisitDivNode(subgraph, logging_context, node_index, node,
874                             context->tensors, div_params, xnnpack_tensors);
875       }
876       case kTfLiteBuiltinElu:
877         return VisitEluNode(subgraph, logging_context, node_index, node,
878                             context->tensors, xnnpack_tensors);
879       case kTfLiteBuiltinFullyConnected: {
880         // FullyConnected with sparse weight has version 8, which cannot be
881         // delegated to XNNPack.
882         if (registration->version == 8) {
883           TF_LITE_MAYBE_KERNEL_LOG(logging_context,
884                                    "Unsupported version %d of FullyConnected.",
885                                    registration->version);
886           return kTfLiteError;
887         }
888 
889         const TfLiteFullyConnectedParams* fc_params =
890             static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
891 
892         return VisitFullyConnectedNode(subgraph, logging_context, node_index,
893                                        node, context->tensors, fc_params,
894                                        quasi_static_tensors, xnnpack_tensors);
895       }
896       case kTfLiteBuiltinFloor:
897         return VisitFloorNode(subgraph, logging_context, node_index, node,
898                               context->tensors, xnnpack_tensors);
899       case kTfLiteBuiltinHardSwish:
900         return VisitHardSwishNode(subgraph, logging_context, node_index, node,
901                                   context->tensors, xnnpack_tensors);
902       case kTfLiteBuiltinLeakyRelu: {
903         const TfLiteLeakyReluParams* leaky_relu_params =
904             static_cast<const TfLiteLeakyReluParams*>(node->builtin_data);
905 
906         return VisitLeakyReluNode(subgraph, logging_context, node_index, node,
907                                   context->tensors, leaky_relu_params,
908                                   xnnpack_tensors);
909       }
910       case kTfLiteBuiltinLogistic:
911         return VisitLogisticNode(subgraph, logging_context, node_index, node,
912                                  context->tensors, xnnpack_tensors);
913       case kTfLiteBuiltinMaxPool2d: {
914         const TfLitePoolParams* pool_params =
915             static_cast<const TfLitePoolParams*>(node->builtin_data);
916 
917         return VisitMaxPool2DNode(subgraph, logging_context, node_index, node,
918                                   context->tensors, pool_params,
919                                   xnnpack_tensors);
920       }
921       case kTfLiteBuiltinMaximum:
922         return VisitMaximumNode(subgraph, logging_context, node_index, node,
923                                 context->tensors, xnnpack_tensors);
924       case kTfLiteBuiltinMean: {
925         const TfLiteReducerParams* reducer_params =
926             static_cast<const TfLiteReducerParams*>(node->builtin_data);
927 
928         return VisitMeanNode(subgraph, logging_context, node_index, node,
929                              context->tensors, reducer_params, xnnpack_tensors);
930       }
931       case kTfLiteBuiltinMinimum:
932         return VisitMinimumNode(subgraph, logging_context, node_index, node,
933                                 context->tensors, xnnpack_tensors);
934       case kTfLiteBuiltinMul: {
935         const TfLiteMulParams* mul_params =
936             static_cast<const TfLiteMulParams*>(node->builtin_data);
937 
938         return VisitMulNode(subgraph, logging_context, node_index, node,
939                             context->tensors, mul_params, xnnpack_tensors);
940       }
941       case kTfLiteBuiltinNeg:
942         return VisitNegNode(subgraph, logging_context, node_index, node,
943                             context->tensors, xnnpack_tensors);
944       case kTfLiteBuiltinPad:
945         return VisitPadNode(subgraph, logging_context, node_index, node,
946                             context->tensors, xnnpack_tensors);
947       case kTfLiteBuiltinPrelu:
948         return VisitPreluNode(subgraph, logging_context, node_index, node,
949                               context->tensors, quasi_static_tensors,
950                               xnnpack_tensors);
951       case kTfLiteBuiltinRelu:
952         return VisitReluNode(
953             subgraph, logging_context, node_index, node, context->tensors, 0.0f,
954             std::numeric_limits<float>::infinity(), xnnpack_tensors);
955       case kTfLiteBuiltinReluN1To1:
956         return VisitReluNode(subgraph, logging_context, node_index, node,
957                              context->tensors, -1.0f, 1.0f, xnnpack_tensors);
958       case kTfLiteBuiltinRelu6:
959         return VisitReluNode(subgraph, logging_context, node_index, node,
960                              context->tensors, 0.0f, 6.0f, xnnpack_tensors);
961       case kTfLiteBuiltinReshape: {
962         const TfLiteReshapeParams* reshape_params =
963             static_cast<const TfLiteReshapeParams*>(node->builtin_data);
964 
965         return VisitReshapeNode(subgraph, logging_context, node_index, node,
966                                 context->tensors, reshape_params,
967                                 xnnpack_tensors);
968       }
969       case kTfLiteBuiltinResizeBilinear: {
970         const TfLiteResizeBilinearParams* resize_params =
971             static_cast<const TfLiteResizeBilinearParams*>(node->builtin_data);
972 
973         return VisitResizeBilinearNode(subgraph, logging_context, node_index,
974                                        node, context->tensors, resize_params,
975                                        xnnpack_tensors);
976       }
977       case kTfLiteBuiltinRound:
978         return VisitRoundNode(subgraph, logging_context, node_index, node,
979                               context->tensors, xnnpack_tensors);
980       case kTfLiteBuiltinSoftmax: {
981         const TfLiteSoftmaxParams* softmax_params =
982             static_cast<const TfLiteSoftmaxParams*>(node->builtin_data);
983 
984         return VisitSoftmaxNode(subgraph, logging_context, node_index, node,
985                                 context->tensors, softmax_params,
986                                 xnnpack_tensors);
987       }
988       case kTfLiteBuiltinSqrt:
989         return VisitSqrtNode(subgraph, logging_context, node_index, node,
990                              context->tensors, xnnpack_tensors);
991       case kTfLiteBuiltinSquare:
992         return VisitSquareNode(subgraph, logging_context, node_index, node,
993                                context->tensors, xnnpack_tensors);
994       case kTfLiteBuiltinSquaredDifference:
995         return VisitSquaredDifferenceNode(subgraph, logging_context, node_index,
996                                           node, context->tensors,
997                                           xnnpack_tensors);
998       case kTfLiteBuiltinSub: {
999         const TfLiteSubParams* sub_params =
1000             static_cast<const TfLiteSubParams*>(node->builtin_data);
1001 
1002         return VisitSubNode(subgraph, logging_context, node_index, node,
1003                             context->tensors, sub_params, xnnpack_tensors);
1004       }
1005       case kTfLiteBuiltinCustom: {
1006         if (strcmp(registration->custom_name, "Convolution2DTransposeBias") ==
1007             0) {
1008           TfLiteTransposeConvParams deconv_params = {kTfLitePaddingUnknown};
1009           std::memcpy(&deconv_params, node->custom_initial_data,
1010                       node->custom_initial_data_size);
1011 
1012           return VisitMediaPipeDeconvolutionNode(
1013               subgraph, context, node_index, node, context->tensors,
1014               &deconv_params, quasi_static_tensors, xnnpack_tensors);
1015         } else if (strcmp(registration->custom_name,
1016                           "MaxPoolingWithArgmax2D") == 0) {
1017           TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
1018           std::memcpy(&pool_params, node->custom_initial_data,
1019                       node->custom_initial_data_size);
1020 
1021           return VisitMediaPipeMaxPoolingNode(subgraph, context, node_index,
1022                                               node, context->tensors,
1023                                               &pool_params, xnnpack_tensors);
1024         } else if (strcmp(registration->custom_name, "MaxUnpooling2D") == 0) {
1025           TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
1026           std::memcpy(&pool_params, node->custom_initial_data,
1027                       node->custom_initial_data_size);
1028 
1029           return VisitMediaPipeUnpoolingNode(subgraph, context, node_index,
1030                                              node, context->tensors,
1031                                              &pool_params, xnnpack_tensors);
1032         }
1033         return kTfLiteError;
1034       }
1035       default:
1036         return kTfLiteError;
1037     }
1038   }
1039 
VisitAbsNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1040   static TfLiteStatus VisitAbsNode(
1041       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1042       TfLiteNode* node, const TfLiteTensor* tensors,
1043       const std::vector<uint32_t>& xnnpack_tensors) {
1044     TF_LITE_ENSURE_STATUS(
1045         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1046 
1047     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1048     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1049         logging_context, input_tensor, node->inputs->data[0], node_index));
1050     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1051         logging_context, input_tensor, node->inputs->data[0], node_index));
1052 
1053     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1054     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1055         logging_context, output_tensor, node->outputs->data[0], node_index));
1056     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1057         logging_context, output_tensor, node->outputs->data[0], node_index));
1058 
1059     if (subgraph != nullptr) {
1060       const xnn_status status = xnn_define_abs(
1061           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1062           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1063       if (status != xnn_status_success) {
1064         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ABS node #%d",
1065                            node_index);
1066         return kTfLiteError;
1067       }
1068     }
1069 
1070     return kTfLiteOk;
1071   }
1072 
VisitAddNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteAddParams * add_params,const std::vector<uint32_t> & xnnpack_tensors)1073   static TfLiteStatus VisitAddNode(
1074       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1075       TfLiteNode* node, const TfLiteTensor* tensors,
1076       const TfLiteAddParams* add_params,
1077       const std::vector<uint32_t>& xnnpack_tensors) {
1078     TF_LITE_ENSURE_STATUS(
1079         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
1080 
1081     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
1082     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1083         logging_context, input1_tensor, node->inputs->data[0], node_index));
1084     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1085         logging_context, input1_tensor, node->inputs->data[0], node_index));
1086 
1087     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
1088     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1089         logging_context, input2_tensor, node->inputs->data[1], node_index));
1090     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1091         logging_context, input2_tensor, node->inputs->data[1], node_index));
1092 
1093     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1094     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1095         logging_context, output_tensor, node->outputs->data[0], node_index));
1096     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1097         logging_context, output_tensor, node->outputs->data[0], node_index));
1098 
1099     float output_min = -std::numeric_limits<float>::infinity();
1100     float output_max = +std::numeric_limits<float>::infinity();
1101     if (add_params != nullptr) {
1102       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1103           logging_context, node_index, add_params->activation, &output_min,
1104           &output_max));
1105     }
1106 
1107     if (subgraph != nullptr) {
1108       const xnn_status status = xnn_define_add2(
1109           subgraph, output_min, output_max,
1110           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
1111           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
1112           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1113       if (status != xnn_status_success) {
1114         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ADD node #%d",
1115                            node_index);
1116         return kTfLiteError;
1117       }
1118     }
1119 
1120     return kTfLiteOk;
1121   }
1122 
VisitAveragePool2DNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLitePoolParams * pool_params,const std::vector<uint32_t> & xnnpack_tensors)1123   static TfLiteStatus VisitAveragePool2DNode(
1124       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1125       TfLiteNode* node, const TfLiteTensor* tensors,
1126       const TfLitePoolParams* pool_params,
1127       const std::vector<uint32_t>& xnnpack_tensors) {
1128     TF_LITE_ENSURE_STATUS(
1129         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1130 
1131     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1132     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1133         logging_context, input_tensor, node->inputs->data[0], node_index));
1134     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1135         logging_context, input_tensor, node->inputs->data[0], node_index));
1136 
1137     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1138     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1139         logging_context, output_tensor, node->outputs->data[0], node_index));
1140     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1141         logging_context, output_tensor, node->outputs->data[0], node_index));
1142 
1143     TF_LITE_ENSURE_STATUS(
1144         CheckPoolingParams(logging_context, pool_params, node_index));
1145 
1146     uint32_t flags = 0;
1147     TF_LITE_ENSURE_STATUS(CalculatePadding(
1148         logging_context, pool_params->padding, &flags, node_index));
1149 
1150     float output_min = -std::numeric_limits<float>::infinity();
1151     float output_max = +std::numeric_limits<float>::infinity();
1152     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1153         logging_context, node_index, pool_params->activation, &output_min,
1154         &output_max));
1155 
1156     if (subgraph != nullptr) {
1157       xnn_status status = xnn_status_success;
1158       if (pool_params->filter_height == 1 && pool_params->filter_width == 1) {
1159         status = xnn_define_clamp(
1160             subgraph, output_min, output_max,
1161             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1162             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1163       } else {
1164         status = xnn_define_average_pooling_2d(
1165             subgraph,
1166             /*input_padding_top=*/0,
1167             /*input_padding_right=*/0,
1168             /*input_padding_bottom=*/0,
1169             /*input_padding_left=*/0,
1170             static_cast<uint32_t>(pool_params->filter_height),
1171             static_cast<uint32_t>(pool_params->filter_width),
1172             static_cast<uint32_t>(pool_params->stride_height),
1173             static_cast<uint32_t>(pool_params->stride_width), output_min,
1174             output_max,
1175             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1176             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
1177       }
1178       if (status != xnn_status_success) {
1179         TF_LITE_KERNEL_LOG(logging_context,
1180                            "failed to delegate AVERAGE_POOL_2D node #%d",
1181                            node_index);
1182         return kTfLiteError;
1183       }
1184     }
1185 
1186     return kTfLiteOk;
1187   }
1188 
VisitCeilNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1189   static TfLiteStatus VisitCeilNode(
1190       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1191       TfLiteNode* node, const TfLiteTensor* tensors,
1192       const std::vector<uint32_t>& xnnpack_tensors) {
1193     TF_LITE_ENSURE_STATUS(
1194         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1195 
1196     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1197     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1198         logging_context, input_tensor, node->inputs->data[0], node_index));
1199     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1200         logging_context, input_tensor, node->inputs->data[0], node_index));
1201 
1202     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1203     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1204         logging_context, output_tensor, node->outputs->data[0], node_index));
1205     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1206         logging_context, output_tensor, node->outputs->data[0], node_index));
1207 
1208     if (subgraph != nullptr) {
1209       const xnn_status status = xnn_define_ceiling(
1210           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1211           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1212       if (status != xnn_status_success) {
1213         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate CEIL node #%d",
1214                            node_index);
1215         return kTfLiteError;
1216       }
1217     }
1218 
1219     return kTfLiteOk;
1220   }
1221 
VisitConv2DNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteConvParams * conv_params,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)1222   static TfLiteStatus VisitConv2DNode(
1223       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1224       TfLiteNode* node, const TfLiteTensor* tensors,
1225       const TfLiteConvParams* conv_params,
1226       const std::unordered_set<int>& quasi_static_tensors,
1227       const std::vector<uint32_t>& xnnpack_tensors) {
1228     TF_LITE_ENSURE_STATUS(
1229         CheckConvolutionParams(logging_context, conv_params, node_index));
1230 
1231     TF_LITE_ENSURE_STATUS(
1232         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
1233 
1234     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1235     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1236         logging_context, input_tensor, node->inputs->data[0], node_index));
1237     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
1238                                            node->inputs->data[0]));
1239     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1240         logging_context, input_tensor, node->inputs->data[0], node_index));
1241 
1242     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
1243     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1244         logging_context, filter_tensor, node->inputs->data[1], node_index));
1245     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
1246                                            node->inputs->data[1]));
1247     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
1248       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1249           logging_context, filter_tensor, node->inputs->data[1], node_index));
1250     }
1251 
1252     const int bias_tensor_id = node->inputs->data[2];
1253     if (bias_tensor_id < 0) {
1254       TF_LITE_MAYBE_KERNEL_LOG(logging_context,
1255                                "unsupported CONV_2D node #%d without bias",
1256                                node_index);
1257       return kTfLiteError;
1258     }
1259     const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
1260     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1261         logging_context, bias_tensor, node->inputs->data[2], node_index));
1262     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
1263                                            node->inputs->data[2]));
1264     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
1265       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1266           logging_context, bias_tensor, node->inputs->data[2], node_index));
1267     }
1268 
1269     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1270     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1271         logging_context, output_tensor, node->outputs->data[0], node_index));
1272     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
1273                                            node->outputs->data[0]));
1274     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1275         logging_context, output_tensor, node->outputs->data[0], node_index));
1276 
1277     const int output_channels = filter_tensor.dims->data[0];
1278     const int kernel_height = filter_tensor.dims->data[1];
1279     const int kernel_width = filter_tensor.dims->data[2];
1280     const int input_channels = filter_tensor.dims->data[3];
1281 
1282     uint32_t flags;
1283     TF_LITE_ENSURE_STATUS(CalculatePadding(
1284         logging_context, conv_params->padding, &flags, node_index));
1285 
1286     float output_min = -std::numeric_limits<float>::infinity();
1287     float output_max = +std::numeric_limits<float>::infinity();
1288     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1289         logging_context, node_index, conv_params->activation, &output_min,
1290         &output_max));
1291 
1292     if (subgraph != nullptr) {
1293       const xnn_status status = xnn_define_convolution_2d(
1294           subgraph,
1295           /*input_padding_top=*/0,
1296           /*input_padding_right=*/0,
1297           /*input_padding_bottom=*/0,
1298           /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
1299           static_cast<uint32_t>(kernel_width),
1300           static_cast<uint32_t>(conv_params->stride_height),
1301           static_cast<uint32_t>(conv_params->stride_width),
1302           static_cast<uint32_t>(conv_params->dilation_height_factor),
1303           static_cast<uint32_t>(conv_params->dilation_width_factor),
1304           /*groups=*/1, static_cast<size_t>(input_channels),
1305           static_cast<size_t>(output_channels), output_min, output_max,
1306           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1307           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
1308           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
1309           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
1310       if (status != xnn_status_success) {
1311         TF_LITE_KERNEL_LOG(logging_context,
1312                            "failed to delegate CONV_2D node #%d", node_index);
1313         return kTfLiteError;
1314       }
1315     }
1316 
1317     return kTfLiteOk;
1318   }
1319 
VisitDepthwiseConv2DNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteDepthwiseConvParams * dwconv_params,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)1320   static TfLiteStatus VisitDepthwiseConv2DNode(
1321       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1322       TfLiteNode* node, const TfLiteTensor* tensors,
1323       const TfLiteDepthwiseConvParams* dwconv_params,
1324       const std::unordered_set<int>& quasi_static_tensors,
1325       const std::vector<uint32_t>& xnnpack_tensors) {
1326     TF_LITE_ENSURE_STATUS(
1327         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
1328 
1329     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1330     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1331         logging_context, input_tensor, node->inputs->data[0], node_index));
1332     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
1333                                            node->inputs->data[0]));
1334     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1335         logging_context, input_tensor, node->inputs->data[0], node_index));
1336 
1337     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
1338     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1339         logging_context, filter_tensor, node->inputs->data[1], node_index));
1340     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
1341                                            node->inputs->data[1]));
1342     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
1343       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1344           logging_context, filter_tensor, node->inputs->data[1], node_index));
1345     }
1346 
1347     const int bias_tensor_id = node->inputs->data[2];
1348     if (bias_tensor_id < 0) {
1349       TF_LITE_MAYBE_KERNEL_LOG(
1350           logging_context,
1351           "unsupported DEPTHWISE_CONV_2D node #%d without bias", node_index);
1352       return kTfLiteError;
1353     }
1354     const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
1355     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1356         logging_context, bias_tensor, node->inputs->data[2], node_index));
1357     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
1358                                            node->inputs->data[2]));
1359     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
1360       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1361           logging_context, bias_tensor, node->inputs->data[2], node_index));
1362     }
1363 
1364     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1365     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1366         logging_context, output_tensor, node->outputs->data[0], node_index));
1367     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
1368                                            node->outputs->data[0]));
1369     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1370         logging_context, output_tensor, node->outputs->data[0], node_index));
1371 
1372     const int kernel_height = filter_tensor.dims->data[1];
1373     const int kernel_width = filter_tensor.dims->data[2];
1374     const int output_channels = filter_tensor.dims->data[3];
1375 
1376     TF_LITE_ENSURE_STATUS(CheckDepthwiseConvolutionParams(
1377         logging_context, dwconv_params, output_channels, node_index));
1378 
1379     uint32_t flags = 0;
1380     TF_LITE_ENSURE_STATUS(CalculatePadding(
1381         logging_context, dwconv_params->padding, &flags, node_index));
1382 
1383     float output_min = -std::numeric_limits<float>::infinity();
1384     float output_max = +std::numeric_limits<float>::infinity();
1385     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1386         logging_context, node_index, dwconv_params->activation, &output_min,
1387         &output_max));
1388 
1389     if (subgraph != nullptr) {
1390       const xnn_status status = xnn_define_depthwise_convolution_2d(
1391           subgraph,
1392           /*input_padding_top=*/0,
1393           /*input_padding_right=*/0,
1394           /*input_padding_bottom=*/0,
1395           /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
1396           static_cast<uint32_t>(kernel_width),
1397           static_cast<uint32_t>(dwconv_params->stride_height),
1398           static_cast<uint32_t>(dwconv_params->stride_width),
1399           static_cast<uint32_t>(dwconv_params->dilation_height_factor),
1400           static_cast<uint32_t>(dwconv_params->dilation_width_factor),
1401           static_cast<uint32_t>(dwconv_params->depth_multiplier),
1402           /*input_channels=*/
1403           static_cast<uint32_t>(output_channels /
1404                                 dwconv_params->depth_multiplier),
1405           output_min, output_max,
1406           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1407           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
1408           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
1409           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
1410       if (status != xnn_status_success) {
1411         TF_LITE_KERNEL_LOG(logging_context,
1412                            "failed to delegate DEPTHWISE_CONV_2D node #%d",
1413                            node_index);
1414         return kTfLiteError;
1415       }
1416     }
1417 
1418     return kTfLiteOk;
1419   }
1420 
VisitDepthToSpaceNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteDepthToSpaceParams * depth_to_space_params,const std::vector<uint32_t> & xnnpack_tensors)1421   static TfLiteStatus VisitDepthToSpaceNode(
1422       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1423       TfLiteNode* node, const TfLiteTensor* tensors,
1424       const TfLiteDepthToSpaceParams* depth_to_space_params,
1425       const std::vector<uint32_t>& xnnpack_tensors) {
1426     TF_LITE_ENSURE_STATUS(
1427         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1428 
1429     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1430     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1431         logging_context, input_tensor, node->inputs->data[0], node_index));
1432     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1433         logging_context, input_tensor, node->inputs->data[0], node_index));
1434 
1435     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1436     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1437         logging_context, output_tensor, node->outputs->data[0], node_index));
1438     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1439         logging_context, output_tensor, node->outputs->data[0], node_index));
1440 
1441     if (depth_to_space_params->block_size <= 1) {
1442       TF_LITE_MAYBE_KERNEL_LOG(
1443           logging_context, "invalid block size (%d) in DEPTH_TO_SPACE node #%d",
1444           depth_to_space_params->block_size, node_index);
1445       return kTfLiteError;
1446     }
1447 
1448     if (subgraph != nullptr) {
1449       const xnn_status status = xnn_define_depth_to_space(
1450           subgraph,
1451           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1452           /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
1453           /*block_size=*/
1454           static_cast<uint32_t>(depth_to_space_params->block_size),
1455           /*flags=*/0);
1456       if (status != xnn_status_success) {
1457         TF_LITE_KERNEL_LOG(logging_context,
1458                            "failed to delegate DEPTH_TO_SPACE node #%d",
1459                            node_index);
1460         return kTfLiteError;
1461       }
1462     }
1463 
1464     return kTfLiteOk;
1465   }
1466 
VisitDivNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteDivParams * div_params,const std::vector<uint32_t> & xnnpack_tensors)1467   static TfLiteStatus VisitDivNode(
1468       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1469       TfLiteNode* node, const TfLiteTensor* tensors,
1470       const TfLiteDivParams* div_params,
1471       const std::vector<uint32_t>& xnnpack_tensors) {
1472     TF_LITE_ENSURE_STATUS(
1473         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
1474 
1475     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
1476     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1477         logging_context, input1_tensor, node->inputs->data[0], node_index));
1478     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1479         logging_context, input1_tensor, node->inputs->data[0], node_index));
1480 
1481     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
1482     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1483         logging_context, input2_tensor, node->inputs->data[1], node_index));
1484     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1485         logging_context, input2_tensor, node->inputs->data[1], node_index));
1486 
1487     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1488     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1489         logging_context, output_tensor, node->outputs->data[0], node_index));
1490     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1491         logging_context, output_tensor, node->outputs->data[0], node_index));
1492 
1493     float output_min = -std::numeric_limits<float>::infinity();
1494     float output_max = +std::numeric_limits<float>::infinity();
1495     if (div_params != nullptr) {
1496       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1497           logging_context, node_index, div_params->activation, &output_min,
1498           &output_max));
1499     }
1500 
1501     if (subgraph != nullptr) {
1502       const xnn_status status = xnn_define_divide(
1503           subgraph, output_min, output_max,
1504           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
1505           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
1506           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1507       if (status != xnn_status_success) {
1508         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate DIV node #%d",
1509                            node_index);
1510         return kTfLiteError;
1511       }
1512     }
1513 
1514     return kTfLiteOk;
1515   }
1516 
VisitEluNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1517   static TfLiteStatus VisitEluNode(
1518       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1519       TfLiteNode* node, const TfLiteTensor* tensors,
1520       const std::vector<uint32_t>& xnnpack_tensors) {
1521     TF_LITE_ENSURE_STATUS(
1522         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1523 
1524     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1525     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1526         logging_context, input_tensor, node->inputs->data[0], node_index));
1527     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1528         logging_context, input_tensor, node->inputs->data[0], node_index));
1529 
1530     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1531     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1532         logging_context, output_tensor, node->outputs->data[0], node_index));
1533     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1534         logging_context, output_tensor, node->outputs->data[0], node_index));
1535 
1536     if (subgraph != nullptr) {
1537       const xnn_status status =
1538           xnn_define_elu(subgraph, /*alpha=*/1.0f,
1539                          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1540                          /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
1541                          /*flags=*/0);
1542       if (status != xnn_status_success) {
1543         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ELU node #%d",
1544                            node_index);
1545         return kTfLiteError;
1546       }
1547     }
1548 
1549     return kTfLiteOk;
1550   }
1551 
VisitFullyConnectedNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteFullyConnectedParams * fc_params,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)1552   static TfLiteStatus VisitFullyConnectedNode(
1553       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1554       TfLiteNode* node, const TfLiteTensor* tensors,
1555       const TfLiteFullyConnectedParams* fc_params,
1556       const std::unordered_set<int>& quasi_static_tensors,
1557       const std::vector<uint32_t>& xnnpack_tensors) {
1558     TF_LITE_ENSURE_STATUS(
1559         CheckFullyConnectedParams(logging_context, fc_params, node_index));
1560 
1561     TF_LITE_ENSURE_STATUS(
1562         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
1563 
1564     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1565     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1566         logging_context, input_tensor, node->inputs->data[0], node_index));
1567     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1568         logging_context, input_tensor, node->inputs->data[0], node_index));
1569 
1570     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
1571     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1572         logging_context, filter_tensor, node->inputs->data[1], node_index));
1573     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 2,
1574                                            node->inputs->data[1]));
1575     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
1576       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1577           logging_context, filter_tensor, node->inputs->data[1], node_index));
1578     }
1579 
1580     const int bias_tensor_id = node->inputs->data[2];
1581     if (bias_tensor_id < 0) {
1582       TF_LITE_MAYBE_KERNEL_LOG(
1583           logging_context, "unsupported FULLY_CONNECTED node #%d without bias",
1584           node_index);
1585       return kTfLiteError;
1586     }
1587     const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
1588     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1589         logging_context, bias_tensor, node->inputs->data[2], node_index));
1590     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
1591                                            node->inputs->data[2]));
1592     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
1593       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1594           logging_context, bias_tensor, node->inputs->data[2], node_index));
1595     }
1596 
1597     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1598     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1599         logging_context, output_tensor, node->outputs->data[0], node_index));
1600     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1601         logging_context, output_tensor, node->outputs->data[0], node_index));
1602 
1603     const int32_t output_channels = filter_tensor.dims->data[0];
1604     const int32_t input_channels = filter_tensor.dims->data[1];
1605 
1606     if (input_tensor.dims->size == 0) {
1607       TF_LITE_MAYBE_KERNEL_LOG(
1608           logging_context,
1609           "unexpected number of shape dimensions %d in tensor #%d",
1610           input_tensor.dims->size, node->inputs->data[0]);
1611       return kTfLiteError;
1612     }
1613 
1614     int32_t num_input_elements = 1;
1615     for (int i = 0; i < input_tensor.dims->size; i++) {
1616       if (input_tensor.dims->data[i] <= 0) {
1617         TF_LITE_MAYBE_KERNEL_LOG(
1618             logging_context, "invalid dimension #%d (%d) in tensor #%d", i,
1619             input_tensor.dims->data[i], node->inputs->data[0]);
1620         return kTfLiteError;
1621       }
1622       num_input_elements *= input_tensor.dims->data[i];
1623     }
1624 
1625     if (fc_params->keep_num_dims) {
1626       TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor,
1627                                              input_tensor.dims->size,
1628                                              node->outputs->data[0]));
1629 
1630       for (int i = 0; i < input_tensor.dims->size - 1; i++) {
1631         if (input_tensor.dims->data[i] != output_tensor.dims->data[i]) {
1632           TF_LITE_MAYBE_KERNEL_LOG(
1633               logging_context,
1634               "mismatch in shape dimension %d (%d != %d) in input and output "
1635               "tensors of FULLY_CONNECTED operator #%d",
1636               i, input_tensor.dims->data[i], output_tensor.dims->data[i],
1637               node_index);
1638           return kTfLiteError;
1639         }
1640       }
1641     } else {
1642       if (num_input_elements % input_channels != 0) {
1643         TF_LITE_MAYBE_KERNEL_LOG(
1644             logging_context,
1645             "number of elements in input tensor #%d in FULLY_CONNECTED "
1646             "operator is not divisible by input channels (%d)",
1647             node->inputs->data[0], input_channels);
1648         return kTfLiteError;
1649       }
1650 
1651       TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 2,
1652                                              node->outputs->data[0]));
1653 
1654       if (output_tensor.dims->data[0] != num_input_elements / input_channels) {
1655         TF_LITE_MAYBE_KERNEL_LOG(
1656             logging_context,
1657             "batch size %d in output tensor #%d in FULLY_CONNECTED operator "
1658             "does not match batch size %d in reshaped input tensor #%d",
1659             output_tensor.dims->data[0], node->outputs->data[0],
1660             num_input_elements / input_channels, node->inputs->data[0]);
1661         return kTfLiteError;
1662       }
1663     }
1664 
1665     if (output_tensor.dims->data[output_tensor.dims->size - 1] !=
1666         output_channels) {
1667       TF_LITE_MAYBE_KERNEL_LOG(
1668           logging_context,
1669           "number of channels %d in output tensor #%d does not match output "
1670           "channels %d in filter tensor #%d",
1671           output_tensor.dims->data[output_tensor.dims->size - 1],
1672           node->outputs->data[0], output_channels, node->inputs->data[1]);
1673       return kTfLiteError;
1674     }
1675 
1676     float output_min = -std::numeric_limits<float>::infinity();
1677     float output_max = +std::numeric_limits<float>::infinity();
1678     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1679         logging_context, node_index, fc_params->activation, &output_min,
1680         &output_max));
1681 
1682     if (subgraph != nullptr) {
1683       const xnn_status status = xnn_define_fully_connected(
1684           subgraph, output_min, output_max,
1685           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1686           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
1687           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
1688           /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
1689           /*flags=*/fc_params->keep_num_dims ? 0
1690                                              : XNN_FLAG_TENSORFLOW_RESHAPE_2D);
1691       if (status != xnn_status_success) {
1692         TF_LITE_KERNEL_LOG(logging_context,
1693                            "failed to delegate FULLY_CONNECTED node #%d",
1694                            node_index);
1695         return kTfLiteError;
1696       }
1697     }
1698 
1699     return kTfLiteOk;
1700   }
1701 
VisitFloorNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1702   static TfLiteStatus VisitFloorNode(
1703       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1704       TfLiteNode* node, const TfLiteTensor* tensors,
1705       const std::vector<uint32_t>& xnnpack_tensors) {
1706     TF_LITE_ENSURE_STATUS(
1707         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1708 
1709     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1710     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1711         logging_context, input_tensor, node->inputs->data[0], node_index));
1712     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1713         logging_context, input_tensor, node->inputs->data[0], node_index));
1714 
1715     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1716     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1717         logging_context, output_tensor, node->outputs->data[0], node_index));
1718     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1719         logging_context, output_tensor, node->outputs->data[0], node_index));
1720 
1721     if (subgraph != nullptr) {
1722       const xnn_status status = xnn_define_floor(
1723           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1724           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1725       if (status != xnn_status_success) {
1726         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate FLOOR node #%d",
1727                            node_index);
1728         return kTfLiteError;
1729       }
1730     }
1731 
1732     return kTfLiteOk;
1733   }
1734 
VisitHardSwishNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1735   static TfLiteStatus VisitHardSwishNode(
1736       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1737       TfLiteNode* node, const TfLiteTensor* tensors,
1738       const std::vector<uint32_t>& xnnpack_tensors) {
1739     TF_LITE_ENSURE_STATUS(
1740         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1741 
1742     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1743     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1744         logging_context, input_tensor, node->inputs->data[0], node_index));
1745     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1746         logging_context, input_tensor, node->inputs->data[0], node_index));
1747 
1748     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1749     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1750         logging_context, output_tensor, node->outputs->data[0], node_index));
1751     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1752         logging_context, output_tensor, node->outputs->data[0], node_index));
1753 
1754     if (subgraph != nullptr) {
1755       const xnn_status status = xnn_define_hardswish(
1756           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1757           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1758       if (status != xnn_status_success) {
1759         TF_LITE_KERNEL_LOG(logging_context,
1760                            "failed to delegate HARD_SWISH node #%d",
1761                            node_index);
1762         return kTfLiteError;
1763       }
1764     }
1765 
1766     return kTfLiteOk;
1767   }
1768 
VisitLeakyReluNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteLeakyReluParams * leaky_relu_params,const std::vector<uint32_t> & xnnpack_tensors)1769   static TfLiteStatus VisitLeakyReluNode(
1770       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1771       TfLiteNode* node, const TfLiteTensor* tensors,
1772       const TfLiteLeakyReluParams* leaky_relu_params,
1773       const std::vector<uint32_t>& xnnpack_tensors) {
1774     TF_LITE_ENSURE_STATUS(
1775         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1776 
1777     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1778     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1779         logging_context, input_tensor, node->inputs->data[0], node_index));
1780     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1781         logging_context, input_tensor, node->inputs->data[0], node_index));
1782 
1783     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1784     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1785         logging_context, output_tensor, node->outputs->data[0], node_index));
1786     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1787         logging_context, output_tensor, node->outputs->data[0], node_index));
1788 
1789     if (subgraph != nullptr) {
1790       const xnn_status status = xnn_define_leaky_relu(
1791           subgraph, leaky_relu_params->alpha,
1792           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1793           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1794       if (status != xnn_status_success) {
1795         TF_LITE_KERNEL_LOG(logging_context,
1796                            "failed to delegate LEAKY_RELU node #%d",
1797                            node_index);
1798         return kTfLiteError;
1799       }
1800     }
1801 
1802     return kTfLiteOk;
1803   }
1804 
VisitLogisticNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1805   static TfLiteStatus VisitLogisticNode(
1806       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1807       TfLiteNode* node, const TfLiteTensor* tensors,
1808       const std::vector<uint32_t>& xnnpack_tensors) {
1809     TF_LITE_ENSURE_STATUS(
1810         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1811 
1812     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1813     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1814         logging_context, input_tensor, node->inputs->data[0], node_index));
1815     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1816         logging_context, input_tensor, node->inputs->data[0], node_index));
1817 
1818     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1819     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1820         logging_context, output_tensor, node->outputs->data[0], node_index));
1821     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1822         logging_context, output_tensor, node->outputs->data[0], node_index));
1823 
1824     if (subgraph != nullptr) {
1825       const xnn_status status = xnn_define_sigmoid(
1826           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1827           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1828       if (status != xnn_status_success) {
1829         TF_LITE_KERNEL_LOG(logging_context,
1830                            "failed to delegate LOGISTIC node #%d", node_index);
1831         return kTfLiteError;
1832       }
1833     }
1834 
1835     return kTfLiteOk;
1836   }
1837 
VisitMaxPool2DNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLitePoolParams * pool_params,const std::vector<uint32_t> & xnnpack_tensors)1838   static TfLiteStatus VisitMaxPool2DNode(
1839       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1840       TfLiteNode* node, const TfLiteTensor* tensors,
1841       const TfLitePoolParams* pool_params,
1842       const std::vector<uint32_t>& xnnpack_tensors) {
1843     TF_LITE_ENSURE_STATUS(
1844         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
1845 
1846     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1847     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1848         logging_context, input_tensor, node->inputs->data[0], node_index));
1849     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1850         logging_context, input_tensor, node->inputs->data[0], node_index));
1851 
1852     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1853     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1854         logging_context, output_tensor, node->outputs->data[0], node_index));
1855     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1856         logging_context, output_tensor, node->outputs->data[0], node_index));
1857 
1858     TF_LITE_ENSURE_STATUS(
1859         CheckPoolingParams(logging_context, pool_params, node_index));
1860 
1861     uint32_t flags = 0;
1862     TF_LITE_ENSURE_STATUS(CalculatePadding(
1863         logging_context, pool_params->padding, &flags, node_index));
1864 
1865     float output_min = -std::numeric_limits<float>::infinity();
1866     float output_max = +std::numeric_limits<float>::infinity();
1867     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
1868         logging_context, node_index, pool_params->activation, &output_min,
1869         &output_max));
1870 
1871     if (subgraph != nullptr) {
1872       xnn_status status = xnn_status_success;
1873       if (pool_params->filter_height == 1 && pool_params->filter_width == 1) {
1874         status = xnn_define_clamp(
1875             subgraph, output_min, output_max,
1876             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1877             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1878       } else {
1879         status = xnn_define_max_pooling_2d(
1880             subgraph,
1881             /*input_padding_top=*/0,
1882             /*input_padding_right=*/0,
1883             /*input_padding_bottom=*/0,
1884             /*input_padding_left=*/0,
1885             static_cast<uint32_t>(pool_params->filter_height),
1886             static_cast<uint32_t>(pool_params->filter_width),
1887             static_cast<uint32_t>(pool_params->stride_height),
1888             static_cast<uint32_t>(pool_params->stride_width),
1889             /*dilation_height=*/1, /*dilation_width=*/1, output_min, output_max,
1890             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
1891             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
1892       }
1893       if (status != xnn_status_success) {
1894         TF_LITE_KERNEL_LOG(logging_context,
1895                            "failed to delegate MAX_POOL_2D node #%d",
1896                            node_index);
1897         return kTfLiteError;
1898       }
1899     }
1900 
1901     return kTfLiteOk;
1902   }
1903 
VisitMaximumNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)1904   static TfLiteStatus VisitMaximumNode(
1905       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1906       TfLiteNode* node, const TfLiteTensor* tensors,
1907       const std::vector<uint32_t>& xnnpack_tensors) {
1908     TF_LITE_ENSURE_STATUS(
1909         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
1910 
1911     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
1912     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1913         logging_context, input1_tensor, node->inputs->data[0], node_index));
1914     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1915         logging_context, input1_tensor, node->inputs->data[0], node_index));
1916 
1917     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
1918     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1919         logging_context, input2_tensor, node->inputs->data[1], node_index));
1920     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1921         logging_context, input2_tensor, node->inputs->data[1], node_index));
1922 
1923     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1924     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1925         logging_context, output_tensor, node->outputs->data[0], node_index));
1926     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1927         logging_context, output_tensor, node->outputs->data[0], node_index));
1928 
1929     if (subgraph != nullptr) {
1930       const xnn_status status = xnn_define_maximum2(
1931           subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
1932           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
1933           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
1934       if (status != xnn_status_success) {
1935         TF_LITE_KERNEL_LOG(logging_context,
1936                            "failed to delegate MAXIMUM node #%d", node_index);
1937         return kTfLiteError;
1938       }
1939     }
1940 
1941     return kTfLiteOk;
1942   }
1943 
VisitMeanNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteReducerParams * reducer_params,const std::vector<uint32_t> & xnnpack_tensors)1944   static TfLiteStatus VisitMeanNode(
1945       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
1946       TfLiteNode* node, const TfLiteTensor* tensors,
1947       const TfLiteReducerParams* reducer_params,
1948       const std::vector<uint32_t>& xnnpack_tensors) {
1949     TF_LITE_ENSURE_STATUS(
1950         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
1951 
1952     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
1953     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1954         logging_context, input_tensor, node->inputs->data[0], node_index));
1955     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
1956                                            node->inputs->data[0]));
1957     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1958         logging_context, input_tensor, node->inputs->data[0], node_index));
1959 
1960     const TfLiteTensor& axes_tensor = tensors[node->inputs->data[1]];
1961     TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, axes_tensor,
1962                                           kTfLiteInt32, node->inputs->data[1],
1963                                           node_index));
1964     TF_LITE_ENSURE_STATUS(CheckAxesTensorShape(
1965         logging_context, axes_tensor, node->inputs->data[1], node_index));
1966     TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
1967         logging_context, axes_tensor, node->inputs->data[1], node_index));
1968 
1969     if (axes_tensor.dims->data[0] != 2) {
1970       TF_LITE_MAYBE_KERNEL_LOG(
1971           logging_context,
1972           "unsupported MEAN reduction along %d axes in node %d",
1973           axes_tensor.dims->data[0], node_index);
1974       return kTfLiteError;
1975     }
1976 
1977     const int32_t* axes_data =
1978         reinterpret_cast<const int32_t*>(axes_tensor.data.data);
1979     if (std::min(axes_data[0], axes_data[1]) != 1 ||
1980         std::max(axes_data[0], axes_data[1]) != 2) {
1981       TF_LITE_MAYBE_KERNEL_LOG(logging_context,
1982                                "unsupported MEAN reduction along non-spatial "
1983                                "axes %d and %d in node %d",
1984                                std::min(axes_data[0], axes_data[1]),
1985                                std::max(axes_data[0], axes_data[1]),
1986                                node_index);
1987       return kTfLiteError;
1988     }
1989 
1990     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
1991     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
1992         logging_context, output_tensor, node->outputs->data[0], node_index));
1993     const int expected_output_dims = reducer_params->keep_dims ? 4 : 2;
1994     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor,
1995                                            expected_output_dims,
1996                                            node->outputs->data[0]));
1997     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
1998         logging_context, output_tensor, node->outputs->data[0], node_index));
1999 
2000     if (subgraph != nullptr) {
2001       const xnn_status status = xnn_define_global_average_pooling_2d(
2002           subgraph,
2003           /*output_min=*/-std::numeric_limits<float>::infinity(),
2004           /*output_max=*/+std::numeric_limits<float>::infinity(),
2005           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2006           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2007       if (status != xnn_status_success) {
2008         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate MEAN node #%d",
2009                            node_index);
2010         return kTfLiteError;
2011       }
2012     }
2013 
2014     return kTfLiteOk;
2015   }
2016 
VisitMediaPipeDeconvolutionNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteTransposeConvParams * deconv_params,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)2017   static TfLiteStatus VisitMediaPipeDeconvolutionNode(
2018       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2019       TfLiteNode* node, const TfLiteTensor* tensors,
2020       const TfLiteTransposeConvParams* deconv_params,
2021       const std::unordered_set<int>& quasi_static_tensors,
2022       const std::vector<uint32_t>& xnnpack_tensors) {
2023     TF_LITE_ENSURE_STATUS(
2024         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
2025 
2026     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2027     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2028         logging_context, input_tensor, node->inputs->data[0], node_index));
2029     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
2030                                            node->inputs->data[0]));
2031     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2032         logging_context, input_tensor, node->inputs->data[0], node_index));
2033 
2034     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
2035     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2036         logging_context, filter_tensor, node->inputs->data[1], node_index));
2037     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
2038                                            node->inputs->data[1]));
2039     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
2040       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2041           logging_context, filter_tensor, node->inputs->data[1], node_index));
2042     }
2043 
2044     const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
2045     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2046         logging_context, bias_tensor, node->inputs->data[2], node_index));
2047     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
2048                                            node->inputs->data[2]));
2049     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
2050       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2051           logging_context, bias_tensor, node->inputs->data[2], node_index));
2052     }
2053 
2054     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2055     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2056         logging_context, output_tensor, node->outputs->data[0], node_index));
2057     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
2058                                            node->outputs->data[0]));
2059     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2060         logging_context, output_tensor, node->outputs->data[0], node_index));
2061 
2062     const int output_channels = filter_tensor.dims->data[0];
2063     const int kernel_height = filter_tensor.dims->data[1];
2064     const int kernel_width = filter_tensor.dims->data[2];
2065     const int input_channels = filter_tensor.dims->data[3];
2066 
2067     TF_LITE_ENSURE_STATUS(CheckMediaPipeTransposedConvolutionParams(
2068         logging_context, deconv_params, node_index));
2069 
2070     uint32_t flags = 0;
2071     TF_LITE_ENSURE_STATUS(CalculatePadding(
2072         logging_context, deconv_params->padding, &flags, node_index));
2073 
2074     if (subgraph != nullptr) {
2075       const xnn_status status = xnn_define_deconvolution_2d(
2076           subgraph,
2077           /*padding_top=*/0,
2078           /*padding_right=*/0,
2079           /*padding_bottom=*/0,
2080           /*padding_left=*/0,
2081           /*adjustment_height=*/0,
2082           /*adjustment_width=*/0, static_cast<uint32_t>(kernel_height),
2083           static_cast<uint32_t>(kernel_width),
2084           static_cast<uint32_t>(deconv_params->stride_height),
2085           static_cast<uint32_t>(deconv_params->stride_width),
2086           /*dilation_height=*/1,
2087           /*dilation_width=*/1,
2088           /*groups=*/1,
2089           /*group_input_channels=*/input_channels,
2090           /*group_output_channels=*/output_channels,
2091           /*output_min=*/-std::numeric_limits<float>::infinity(),
2092           /*output_max=*/+std::numeric_limits<float>::infinity(),
2093           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2094           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
2095           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
2096           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
2097       if (status != xnn_status_success) {
2098         TF_LITE_KERNEL_LOG(
2099             logging_context,
2100             "failed to delegate Convolution2DTransposeBias node #%d",
2101             node_index);
2102         return kTfLiteError;
2103       }
2104     }
2105 
2106     return kTfLiteOk;
2107   }
2108 
VisitMediaPipeMaxPoolingNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLitePoolParams * pool_params,const std::vector<uint32_t> & xnnpack_tensors)2109   static TfLiteStatus VisitMediaPipeMaxPoolingNode(
2110       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2111       TfLiteNode* node, const TfLiteTensor* tensors,
2112       const TfLitePoolParams* pool_params,
2113       const std::vector<uint32_t>& xnnpack_tensors) {
2114     TF_LITE_ENSURE_STATUS(
2115         CheckNumInputsAndOutputs(logging_context, node, 1, 2, node_index));
2116 
2117     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2118     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2119         logging_context, input_tensor, node->inputs->data[0], node_index));
2120     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
2121                                            node->inputs->data[0]));
2122     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2123         logging_context, input_tensor, node->inputs->data[0], node_index));
2124 
2125     const TfLiteTensor& output_value_tensor = tensors[node->outputs->data[0]];
2126     TF_LITE_ENSURE_STATUS(
2127         CheckTensorFloatType(logging_context, output_value_tensor,
2128                              node->outputs->data[0], node_index));
2129     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_value_tensor,
2130                                            4, node->outputs->data[0]));
2131     TF_LITE_ENSURE_STATUS(
2132         CheckTensorNonDynamicAllocation(logging_context, output_value_tensor,
2133                                         node->outputs->data[0], node_index));
2134 
2135     const TfLiteTensor& output_index_tensor = tensors[node->outputs->data[1]];
2136     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_index_tensor,
2137                                            4, node->outputs->data[1]));
2138     TF_LITE_ENSURE_STATUS(
2139         CheckTensorNonDynamicAllocation(logging_context, output_index_tensor,
2140                                         node->outputs->data[1], node_index));
2141 
2142     TF_LITE_ENSURE_STATUS(
2143         CheckMediaPipePoolParams(logging_context, pool_params, node_index));
2144 
2145     uint32_t flags = 0;
2146     TF_LITE_ENSURE_STATUS(CalculatePadding(
2147         logging_context, pool_params->padding, &flags, node_index));
2148 
2149     if (subgraph != nullptr) {
2150       const xnn_status status = xnn_define_argmax_pooling_2d(
2151           subgraph,
2152           /*input_padding_top=*/0,
2153           /*input_padding_right=*/0,
2154           /*input_padding_bottom=*/0,
2155           /*input_padding_left=*/0,
2156           static_cast<uint32_t>(pool_params->filter_height),
2157           static_cast<uint32_t>(pool_params->filter_width),
2158           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2159           /*output_value_id=*/xnnpack_tensors[node->outputs->data[0]],
2160           /*output_index_id=*/xnnpack_tensors[node->outputs->data[1]], flags);
2161       if (status != xnn_status_success) {
2162         TF_LITE_KERNEL_LOG(
2163             logging_context,
2164             "failed to delegate CUSTOM(MaxPoolingWithArgmax2D) node #%d",
2165             node_index);
2166         return kTfLiteError;
2167       }
2168     }
2169 
2170     return kTfLiteOk;
2171   }
2172 
VisitMediaPipeUnpoolingNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLitePoolParams * pool_params,const std::vector<uint32_t> & xnnpack_tensors)2173   static TfLiteStatus VisitMediaPipeUnpoolingNode(
2174       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2175       TfLiteNode* node, const TfLiteTensor* tensors,
2176       const TfLitePoolParams* pool_params,
2177       const std::vector<uint32_t>& xnnpack_tensors) {
2178     TF_LITE_ENSURE_STATUS(
2179         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2180 
2181     const TfLiteTensor& input_value_tensor = tensors[node->inputs->data[0]];
2182     TF_LITE_ENSURE_STATUS(
2183         CheckTensorFloatType(logging_context, input_value_tensor,
2184                              node->inputs->data[0], node_index));
2185     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_value_tensor,
2186                                            4, node->inputs->data[0]));
2187     TF_LITE_ENSURE_STATUS(
2188         CheckTensorNonDynamicAllocation(logging_context, input_value_tensor,
2189                                         node->inputs->data[0], node_index));
2190 
2191     const TfLiteTensor& input_index_tensor = tensors[node->inputs->data[1]];
2192     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_index_tensor,
2193                                            4, node->inputs->data[1]));
2194     TF_LITE_ENSURE_STATUS(
2195         CheckTensorNonDynamicAllocation(logging_context, input_index_tensor,
2196                                         node->inputs->data[1], node_index));
2197 
2198     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2199     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2200         logging_context, output_tensor, node->outputs->data[0], node_index));
2201     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
2202                                            node->outputs->data[0]));
2203     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2204         logging_context, output_tensor, node->outputs->data[0], node_index));
2205 
2206     TF_LITE_ENSURE_STATUS(
2207         CheckMediaPipePoolParams(logging_context, pool_params, node_index));
2208 
2209     uint32_t flags = 0;
2210     TF_LITE_ENSURE_STATUS(CalculatePadding(
2211         logging_context, pool_params->padding, &flags, node_index));
2212     if (flags != 0) {
2213       TF_LITE_MAYBE_KERNEL_LOG(
2214           logging_context, "invalid padding mode (%d) in node #%d",
2215           static_cast<int>(pool_params->padding), node_index);
2216     }
2217 
2218     if (subgraph != nullptr) {
2219       const xnn_status status = xnn_define_unpooling_2d(
2220           subgraph,
2221           /*padding_top=*/0,
2222           /*padding_right=*/0,
2223           /*padding_bottom=*/0,
2224           /*padding_left=*/0, static_cast<uint32_t>(pool_params->filter_height),
2225           static_cast<uint32_t>(pool_params->filter_width),
2226           /*input_value_id=*/xnnpack_tensors[node->inputs->data[0]],
2227           /*input_index_id=*/xnnpack_tensors[node->inputs->data[1]],
2228           /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2229           /*flags=*/0);
2230       if (status != xnn_status_success) {
2231         TF_LITE_KERNEL_LOG(logging_context,
2232                            "failed to delegate CUSTOM(MaxUnpooling2D) node #%d",
2233                            node_index);
2234         return kTfLiteError;
2235       }
2236     }
2237 
2238     return kTfLiteOk;
2239   }
2240 
VisitMinimumNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2241   static TfLiteStatus VisitMinimumNode(
2242       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2243       TfLiteNode* node, const TfLiteTensor* tensors,
2244       const std::vector<uint32_t>& xnnpack_tensors) {
2245     TF_LITE_ENSURE_STATUS(
2246         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2247 
2248     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2249     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2250         logging_context, input1_tensor, node->inputs->data[0], node_index));
2251     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2252         logging_context, input1_tensor, node->inputs->data[0], node_index));
2253 
2254     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2255     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2256         logging_context, input2_tensor, node->inputs->data[1], node_index));
2257     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2258         logging_context, input2_tensor, node->inputs->data[1], node_index));
2259 
2260     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2261     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2262         logging_context, output_tensor, node->outputs->data[0], node_index));
2263     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2264         logging_context, output_tensor, node->outputs->data[0], node_index));
2265 
2266     if (subgraph != nullptr) {
2267       const xnn_status status = xnn_define_minimum2(
2268           subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2269           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2270           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2271       if (status != xnn_status_success) {
2272         TF_LITE_KERNEL_LOG(logging_context,
2273                            "failed to delegate MINIMUM node #%d", node_index);
2274         return kTfLiteError;
2275       }
2276     }
2277 
2278     return kTfLiteOk;
2279   }
2280 
VisitMulNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteMulParams * mul_params,const std::vector<uint32_t> & xnnpack_tensors)2281   static TfLiteStatus VisitMulNode(
2282       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2283       TfLiteNode* node, const TfLiteTensor* tensors,
2284       const TfLiteMulParams* mul_params,
2285       const std::vector<uint32_t>& xnnpack_tensors) {
2286     TF_LITE_ENSURE_STATUS(
2287         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2288 
2289     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2290     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2291         logging_context, input1_tensor, node->inputs->data[0], node_index));
2292     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2293         logging_context, input1_tensor, node->inputs->data[0], node_index));
2294 
2295     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2296     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2297         logging_context, input2_tensor, node->inputs->data[1], node_index));
2298     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2299         logging_context, input2_tensor, node->inputs->data[1], node_index));
2300 
2301     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2302     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2303         logging_context, output_tensor, node->outputs->data[0], node_index));
2304     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2305         logging_context, output_tensor, node->outputs->data[0], node_index));
2306 
2307     float output_min = -std::numeric_limits<float>::infinity();
2308     float output_max = +std::numeric_limits<float>::infinity();
2309     if (mul_params != nullptr) {
2310       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2311           logging_context, node_index, mul_params->activation, &output_min,
2312           &output_max));
2313     }
2314 
2315     if (subgraph != nullptr) {
2316       const xnn_status status = xnn_define_multiply2(
2317           subgraph, output_min, output_max,
2318           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2319           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2320           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2321       if (status != xnn_status_success) {
2322         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate MUL node #%d",
2323                            node_index);
2324         return kTfLiteError;
2325       }
2326     }
2327 
2328     return kTfLiteOk;
2329   }
2330 
VisitNegNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2331   static TfLiteStatus VisitNegNode(
2332       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2333       TfLiteNode* node, const TfLiteTensor* tensors,
2334       const std::vector<uint32_t>& xnnpack_tensors) {
2335     TF_LITE_ENSURE_STATUS(
2336         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2337 
2338     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2339     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2340         logging_context, input_tensor, node->inputs->data[0], node_index));
2341     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2342         logging_context, input_tensor, node->inputs->data[0], node_index));
2343 
2344     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2345     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2346         logging_context, output_tensor, node->outputs->data[0], node_index));
2347     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2348         logging_context, output_tensor, node->outputs->data[0], node_index));
2349 
2350     if (subgraph != nullptr) {
2351       const xnn_status status = xnn_define_negate(
2352           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2353           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2354       if (status != xnn_status_success) {
2355         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate NEG node #%d",
2356                            node_index);
2357         return kTfLiteError;
2358       }
2359     }
2360 
2361     return kTfLiteOk;
2362   }
2363 
VisitPadNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2364   static TfLiteStatus VisitPadNode(
2365       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2366       TfLiteNode* node, const TfLiteTensor* tensors,
2367       const std::vector<uint32_t>& xnnpack_tensors) {
2368     TF_LITE_ENSURE_STATUS(
2369         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2370 
2371     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2372     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2373         logging_context, input_tensor, node->inputs->data[0], node_index));
2374     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1,
2375                                            XNN_MAX_TENSOR_DIMS,
2376                                            node->inputs->data[0]));
2377     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2378         logging_context, input_tensor, node->inputs->data[0], node_index));
2379 
2380     const TfLiteTensor& paddings_tensor = tensors[node->inputs->data[1]];
2381     TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, paddings_tensor,
2382                                           kTfLiteInt32, node->inputs->data[1],
2383                                           node_index));
2384     TF_LITE_ENSURE_STATUS(CheckPaddingsTensorShape(
2385         logging_context, paddings_tensor, input_tensor.dims->size,
2386         node->inputs->data[1], node_index));
2387     TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2388         logging_context, paddings_tensor, node->inputs->data[1], node_index));
2389 
2390     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2391     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2392         logging_context, output_tensor, node->outputs->data[0], node_index));
2393     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1,
2394                                            XNN_MAX_TENSOR_DIMS,
2395                                            node->outputs->data[0]));
2396     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2397         logging_context, output_tensor, node->outputs->data[0], node_index));
2398 
2399     const int32_t* paddings_data =
2400         reinterpret_cast<const int32_t*>(paddings_tensor.data.data);
2401     for (int i = 0; i < paddings_tensor.dims->size; i++) {
2402       const int32_t pre_padding = paddings_data[i * 2 + 0];
2403       if (pre_padding < 0) {
2404         TF_LITE_MAYBE_KERNEL_LOG(
2405             logging_context,
2406             "invalid pre-padding %d for dimension #%d in node %d", pre_padding,
2407             i, node_index);
2408         return kTfLiteError;
2409       }
2410 
2411       const int32_t post_padding = paddings_data[i * 2 + 1];
2412       if (post_padding < 0) {
2413         TF_LITE_MAYBE_KERNEL_LOG(
2414             logging_context,
2415             "invalid post-padding %d for dimension #%d in node %d", pre_padding,
2416             i, node_index);
2417         return kTfLiteError;
2418       }
2419     }
2420 
2421     if (subgraph != nullptr) {
2422       std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings{};
2423       std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings{};
2424       for (int i = 0; i < paddings_tensor.dims->data[0]; i++) {
2425         pre_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 0]);
2426         post_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 1]);
2427       }
2428 
2429       const xnn_status status = xnn_define_static_constant_pad(
2430           subgraph, pre_paddings.data(), post_paddings.data(),
2431           /*padding_value=*/0.0f,
2432           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2433           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2434       if (status != xnn_status_success) {
2435         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate PAD node #%d",
2436                            node_index);
2437         return kTfLiteError;
2438       }
2439     }
2440 
2441     return kTfLiteOk;
2442   }
2443 
VisitPreluNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::unordered_set<int> & quasi_static_tensors,const std::vector<uint32_t> & xnnpack_tensors)2444   static TfLiteStatus VisitPreluNode(
2445       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2446       TfLiteNode* node, const TfLiteTensor* tensors,
2447       const std::unordered_set<int>& quasi_static_tensors,
2448       const std::vector<uint32_t>& xnnpack_tensors) {
2449     TF_LITE_ENSURE_STATUS(
2450         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2451 
2452     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2453     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2454         logging_context, input_tensor, node->inputs->data[0], node_index));
2455     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1,
2456                                            XNN_MAX_TENSOR_DIMS,
2457                                            node->inputs->data[0]));
2458     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2459         logging_context, input_tensor, node->inputs->data[0], node_index));
2460 
2461     const TfLiteTensor& slope_tensor = tensors[node->inputs->data[1]];
2462     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2463         logging_context, slope_tensor, node->inputs->data[1], node_index));
2464     TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape(
2465         logging_context, slope_tensor, node->inputs->data[1], node_index));
2466     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
2467       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2468           logging_context, slope_tensor, node->inputs->data[1], node_index));
2469     }
2470 
2471     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2472     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2473         logging_context, output_tensor, node->outputs->data[0], node_index));
2474     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1,
2475                                            XNN_MAX_TENSOR_DIMS,
2476                                            node->outputs->data[0]));
2477     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2478         logging_context, output_tensor, node->outputs->data[0], node_index));
2479 
2480     if (subgraph != nullptr) {
2481       const xnn_status status = xnn_define_prelu(
2482           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2483           /*slope_id=*/xnnpack_tensors[node->inputs->data[1]],
2484           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2485       if (status != xnn_status_success) {
2486         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate PRELU node #%d",
2487                            node_index);
2488         return kTfLiteError;
2489       }
2490     }
2491 
2492     return kTfLiteOk;
2493   }
2494 
VisitReluNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,float output_min,float output_max,const std::vector<uint32_t> & xnnpack_tensors)2495   static TfLiteStatus VisitReluNode(
2496       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2497       TfLiteNode* node, const TfLiteTensor* tensors, float output_min,
2498       float output_max, const std::vector<uint32_t>& xnnpack_tensors) {
2499     TF_LITE_ENSURE_STATUS(
2500         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2501 
2502     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2503     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2504         logging_context, input_tensor, node->inputs->data[0], node_index));
2505     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2506         logging_context, input_tensor, node->inputs->data[0], node_index));
2507 
2508     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2509     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2510         logging_context, output_tensor, node->outputs->data[0], node_index));
2511     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2512         logging_context, output_tensor, node->outputs->data[0], node_index));
2513 
2514     if (subgraph != nullptr) {
2515       const xnn_status status = xnn_define_clamp(
2516           subgraph, output_min, output_max,
2517           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2518           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2519       if (status != xnn_status_success) {
2520         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate RELU node #%d",
2521                            node_index);
2522         return kTfLiteError;
2523       }
2524     }
2525 
2526     return kTfLiteOk;
2527   }
2528 
VisitReshapeNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteReshapeParams * reshape_params,const std::vector<uint32_t> & xnnpack_tensors)2529   static TfLiteStatus VisitReshapeNode(
2530       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2531       TfLiteNode* node, const TfLiteTensor* tensors,
2532       const TfLiteReshapeParams* reshape_params,
2533       const std::vector<uint32_t>& xnnpack_tensors) {
2534     switch (node->inputs->size) {
2535       case 1:
2536       case 2:
2537         break;
2538       default:
2539         TF_LITE_MAYBE_KERNEL_LOG(
2540             logging_context,
2541             "unexpected number of inputs (%d) in node #%d: "
2542             "either one or two inputs expected",
2543             node->inputs->size, node_index);
2544         return kTfLiteError;
2545     }
2546     if (node->outputs->size != 1) {
2547       TF_LITE_MAYBE_KERNEL_LOG(
2548           logging_context,
2549           "unexpected number of outputs (%d) in node #%d: one output expected",
2550           node->outputs->size, node_index);
2551       return kTfLiteError;
2552     }
2553 
2554     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2555     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2556         logging_context, input_tensor, node->inputs->data[0], node_index));
2557     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 0,
2558                                            XNN_MAX_TENSOR_DIMS,
2559                                            node->inputs->data[0]));
2560     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2561         logging_context, input_tensor, node->inputs->data[0], node_index));
2562 
2563     if (node->inputs->size == 2) {
2564       const TfLiteTensor& shape_tensor = tensors[node->inputs->data[1]];
2565       TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, shape_tensor,
2566                                             kTfLiteInt32, node->inputs->data[1],
2567                                             node_index));
2568       TF_LITE_ENSURE_STATUS(CheckShapeTensorShape(
2569           logging_context, shape_tensor, node->inputs->data[1], node_index));
2570       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2571           logging_context, shape_tensor, node->inputs->data[1], node_index));
2572     }
2573 
2574     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2575     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2576         logging_context, output_tensor, node->outputs->data[0], node_index));
2577     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 0,
2578                                            XNN_MAX_TENSOR_DIMS,
2579                                            node->outputs->data[0]));
2580     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2581         logging_context, output_tensor, node->outputs->data[0], node_index));
2582 
2583     if (subgraph != nullptr) {
2584       std::array<size_t, XNN_MAX_TENSOR_DIMS> new_shape;
2585       std::copy(&output_tensor.dims->data[0],
2586                 &output_tensor.dims->data[output_tensor.dims->size],
2587                 new_shape.begin());
2588       const xnn_status status = xnn_define_static_reshape(
2589           subgraph, static_cast<size_t>(output_tensor.dims->size),
2590           new_shape.data(),
2591           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2592           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2593       if (status != xnn_status_success) {
2594         TF_LITE_KERNEL_LOG(logging_context,
2595                            "failed to delegate RESHAPE node #%d", node_index);
2596         return kTfLiteError;
2597       }
2598     }
2599 
2600     return kTfLiteOk;
2601   }
2602 
VisitResizeBilinearNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteResizeBilinearParams * resize_params,const std::vector<uint32_t> & xnnpack_tensors)2603   static TfLiteStatus VisitResizeBilinearNode(
2604       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2605       TfLiteNode* node, const TfLiteTensor* tensors,
2606       const TfLiteResizeBilinearParams* resize_params,
2607       const std::vector<uint32_t>& xnnpack_tensors) {
2608     TF_LITE_ENSURE_STATUS(
2609         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2610 
2611     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2612     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2613         logging_context, input_tensor, node->inputs->data[0], node_index));
2614     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
2615                                            node->inputs->data[0]));
2616     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2617         logging_context, input_tensor, node->inputs->data[0], node_index));
2618 
2619     const TfLiteTensor& shape_tensor = tensors[node->inputs->data[1]];
2620     TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, shape_tensor,
2621                                           kTfLiteInt32, node->inputs->data[1],
2622                                           node_index));
2623     TF_LITE_ENSURE_STATUS(CheckShapeTensorShape(
2624         logging_context, shape_tensor, node->inputs->data[1], node_index));
2625     if (shape_tensor.dims->data[0] != 2) {
2626       TF_LITE_MAYBE_KERNEL_LOG(
2627           logging_context,
2628           "unexpected number of dimensions %d in the output shape in node %d",
2629           shape_tensor.dims->data[0], node_index);
2630     }
2631     TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2632         logging_context, shape_tensor, node->inputs->data[1], node_index));
2633 
2634     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2635     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2636         logging_context, output_tensor, node->outputs->data[0], node_index));
2637     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
2638                                            node->outputs->data[0]));
2639     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2640         logging_context, output_tensor, node->outputs->data[0], node_index));
2641 
2642     const int32_t* shape_data =
2643         reinterpret_cast<const int32_t*>(shape_tensor.data.data);
2644     for (int i = 0; i < shape_tensor.dims->size; i++) {
2645       const int32_t dim = shape_data[i];
2646       if (dim <= 0) {
2647         TF_LITE_MAYBE_KERNEL_LOG(
2648             logging_context, "invalid output dimension #%d value %d in node %d",
2649             i, dim, node_index);
2650         return kTfLiteError;
2651       }
2652     }
2653 
2654     if (subgraph != nullptr) {
2655       uint32_t flags = 0;
2656       if (resize_params->align_corners) {
2657         flags |= XNN_FLAG_ALIGN_CORNERS;
2658       } else if (!resize_params->half_pixel_centers) {
2659         flags |= XNN_FLAG_TENSORFLOW_LEGACY_MODE;
2660       }
2661       const xnn_status status = xnn_define_static_resize_bilinear_2d(
2662           subgraph, static_cast<size_t>(shape_data[0]),
2663           static_cast<size_t>(shape_data[1]),
2664           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2665           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
2666       if (status != xnn_status_success) {
2667         TF_LITE_KERNEL_LOG(logging_context,
2668                            "failed to delegate RESIZE_BILINEAR node #%d",
2669                            node_index);
2670         return kTfLiteError;
2671       }
2672     }
2673 
2674     return kTfLiteOk;
2675   }
2676 
VisitRoundNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2677   static TfLiteStatus VisitRoundNode(
2678       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2679       TfLiteNode* node, const TfLiteTensor* tensors,
2680       const std::vector<uint32_t>& xnnpack_tensors) {
2681     TF_LITE_ENSURE_STATUS(
2682         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2683 
2684     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2685     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2686         logging_context, input_tensor, node->inputs->data[0], node_index));
2687     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2688         logging_context, input_tensor, node->inputs->data[0], node_index));
2689 
2690     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2691     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2692         logging_context, output_tensor, node->outputs->data[0], node_index));
2693     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2694         logging_context, output_tensor, node->outputs->data[0], node_index));
2695 
2696     if (subgraph != nullptr) {
2697       const xnn_status status = xnn_define_bankers_rounding(
2698           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2699           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2700       if (status != xnn_status_success) {
2701         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ROUND node #%d",
2702                            node_index);
2703         return kTfLiteError;
2704       }
2705     }
2706 
2707     return kTfLiteOk;
2708   }
2709 
VisitSoftmaxNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteSoftmaxParams * params,const std::vector<uint32_t> & xnnpack_tensors)2710   static TfLiteStatus VisitSoftmaxNode(
2711       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2712       TfLiteNode* node, const TfLiteTensor* tensors,
2713       const TfLiteSoftmaxParams* params,
2714       const std::vector<uint32_t>& xnnpack_tensors) {
2715     if (params->beta != 1.0f) {
2716       if (logging_context != nullptr) {
2717         TF_LITE_KERNEL_LOG(logging_context,
2718                            "unsupported beta value %.7f in SOFTMAX node #%d",
2719                            params->beta, node_index);
2720       }
2721       return kTfLiteError;
2722     }
2723 
2724     TF_LITE_ENSURE_STATUS(
2725         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2726 
2727     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2728     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2729         logging_context, input_tensor, node->inputs->data[0], node_index));
2730     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2731         logging_context, input_tensor, node->inputs->data[0], node_index));
2732 
2733     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2734     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2735         logging_context, output_tensor, node->outputs->data[0], node_index));
2736     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2737         logging_context, output_tensor, node->outputs->data[0], node_index));
2738 
2739     if (subgraph != nullptr) {
2740       const xnn_status status = xnn_define_softmax(
2741           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2742           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2743       if (status != xnn_status_success) {
2744         TF_LITE_KERNEL_LOG(logging_context,
2745                            "failed to delegate SOFTMAX node #%d", node_index);
2746         return kTfLiteError;
2747       }
2748     }
2749 
2750     return kTfLiteOk;
2751   }
2752 
VisitSquareNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2753   static TfLiteStatus VisitSquareNode(
2754       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2755       TfLiteNode* node, const TfLiteTensor* tensors,
2756       const std::vector<uint32_t>& xnnpack_tensors) {
2757     TF_LITE_ENSURE_STATUS(
2758         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2759 
2760     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2761     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2762         logging_context, input_tensor, node->inputs->data[0], node_index));
2763     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2764         logging_context, input_tensor, node->inputs->data[0], node_index));
2765 
2766     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2767     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2768         logging_context, output_tensor, node->outputs->data[0], node_index));
2769     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2770         logging_context, output_tensor, node->outputs->data[0], node_index));
2771 
2772     if (subgraph != nullptr) {
2773       const xnn_status status = xnn_define_square(
2774           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2775           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2776       if (status != xnn_status_success) {
2777         TF_LITE_KERNEL_LOG(logging_context,
2778                            "failed to delegate SQUARE node #%d", node_index);
2779         return kTfLiteError;
2780       }
2781     }
2782 
2783     return kTfLiteOk;
2784   }
2785 
VisitSqrtNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2786   static TfLiteStatus VisitSqrtNode(
2787       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2788       TfLiteNode* node, const TfLiteTensor* tensors,
2789       const std::vector<uint32_t>& xnnpack_tensors) {
2790     TF_LITE_ENSURE_STATUS(
2791         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2792 
2793     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2794     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2795         logging_context, input_tensor, node->inputs->data[0], node_index));
2796     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2797         logging_context, input_tensor, node->inputs->data[0], node_index));
2798 
2799     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2800     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2801         logging_context, output_tensor, node->outputs->data[0], node_index));
2802     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2803         logging_context, output_tensor, node->outputs->data[0], node_index));
2804 
2805     if (subgraph != nullptr) {
2806       const xnn_status status = xnn_define_square_root(
2807           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2808           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2809       if (status != xnn_status_success) {
2810         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SQRT node #%d",
2811                            node_index);
2812         return kTfLiteError;
2813       }
2814     }
2815 
2816     return kTfLiteOk;
2817   }
2818 
VisitSquaredDifferenceNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const std::vector<uint32_t> & xnnpack_tensors)2819   static TfLiteStatus VisitSquaredDifferenceNode(
2820       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2821       TfLiteNode* node, const TfLiteTensor* tensors,
2822       const std::vector<uint32_t>& xnnpack_tensors) {
2823     TF_LITE_ENSURE_STATUS(
2824         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2825 
2826     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2827     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2828         logging_context, input1_tensor, node->inputs->data[0], node_index));
2829     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2830         logging_context, input1_tensor, node->inputs->data[0], node_index));
2831 
2832     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2833     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2834         logging_context, input2_tensor, node->inputs->data[1], node_index));
2835     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2836         logging_context, input2_tensor, node->inputs->data[1], node_index));
2837 
2838     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2839     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2840         logging_context, output_tensor, node->outputs->data[0], node_index));
2841     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2842         logging_context, output_tensor, node->outputs->data[0], node_index));
2843 
2844     if (subgraph != nullptr) {
2845       const xnn_status status = xnn_define_squared_difference(
2846           subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2847           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2848           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2849       if (status != xnn_status_success) {
2850         TF_LITE_KERNEL_LOG(logging_context,
2851                            "failed to delegate SQUARED_DIFFERENCE node #%d",
2852                            node_index);
2853         return kTfLiteError;
2854       }
2855     }
2856 
2857     return kTfLiteOk;
2858   }
2859 
VisitSubNode(xnn_subgraph_t subgraph,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteSubParams * sub_params,const std::vector<uint32_t> & xnnpack_tensors)2860   static TfLiteStatus VisitSubNode(
2861       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
2862       TfLiteNode* node, const TfLiteTensor* tensors,
2863       const TfLiteSubParams* sub_params,
2864       const std::vector<uint32_t>& xnnpack_tensors) {
2865     TF_LITE_ENSURE_STATUS(
2866         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2867 
2868     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2869     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2870         logging_context, input1_tensor, node->inputs->data[0], node_index));
2871     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2872         logging_context, input1_tensor, node->inputs->data[0], node_index));
2873 
2874     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2875     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2876         logging_context, input2_tensor, node->inputs->data[1], node_index));
2877     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2878         logging_context, input2_tensor, node->inputs->data[1], node_index));
2879 
2880     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2881     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
2882         logging_context, output_tensor, node->outputs->data[0], node_index));
2883     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2884         logging_context, output_tensor, node->outputs->data[0], node_index));
2885 
2886     float output_min = -std::numeric_limits<float>::infinity();
2887     float output_max = +std::numeric_limits<float>::infinity();
2888     if (sub_params != nullptr) {
2889       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2890           logging_context, node_index, sub_params->activation, &output_min,
2891           &output_max));
2892     }
2893 
2894     if (subgraph != nullptr) {
2895       const xnn_status status = xnn_define_subtract(
2896           subgraph, output_min, output_max,
2897           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2898           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2899           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2900       if (status != xnn_status_success) {
2901         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SUB node #%d",
2902                            node_index);
2903         return kTfLiteError;
2904       }
2905     }
2906 
2907     return kTfLiteOk;
2908   }
2909 
2910  private:
Subgraph(xnn_runtime_t runtime,std::unordered_set<int> && externals)2911   Subgraph(xnn_runtime_t runtime, std::unordered_set<int>&& externals)
2912       : runtime_(runtime, &xnn_delete_runtime), externals_(externals) {}
2913 
2914   // XNNPACK Runtime (subgraph + workspace) with smart-pointer for lifetime
2915   // management.
2916   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{
2917       nullptr, &xnn_delete_runtime};
2918   // TFLite Tensor IDs == XNNPACK Value IDs of input/output tensors for the
2919   // delegated subgraph.
2920   std::unordered_set<int> externals_;
2921   bool first_run_{true};
2922 };
2923 
PrepareOpsToDelegate(TfLiteContext * context)2924 TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) {
2925   // Clear previous data, in case the delegate is reused without re-creation.
2926   static_unpacked_data_map_.clear();
2927   static_unpacked_data_.clear();
2928   static_unpack_nodes_.clear();
2929   static_sparse_weights_.clear();
2930 
2931   TfLiteIntArray* execution_plan = nullptr;
2932   if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
2933     TF_LITE_KERNEL_LOG(context, "Unable to get graph execution plan.");
2934     return nullptr;
2935   }
2936 
2937   // Mapping for quasi-static (unpacked from static) tensor index to the node
2938   // index that produced it.
2939   std::unordered_map<int, int> quasi_static_tensors_producers;
2940   // Set of all quasi-static tensors in the execution plan.
2941   std::unordered_set<int> quasi_static_tensors;
2942   // Set of quasi-static tensors consumed by the delegated nodes.
2943   std::unordered_set<int> quasi_static_tensors_to_unpack;
2944 
2945   TfLiteIntArray* nodes_to_delegate =
2946       TfLiteIntArrayCreate(execution_plan->size);
2947   nodes_to_delegate->size = 0;
2948   for (int i = 0; i < execution_plan->size; ++i) {
2949     const int node_index = execution_plan->data[i];
2950 
2951     // Check if TFLite nodes can be delegated to XNNPACK
2952     TfLiteNode* node = nullptr;
2953     TfLiteRegistration* registration = nullptr;
2954     if (context->GetNodeAndRegistration(context, node_index, &node,
2955                                         &registration) != kTfLiteOk) {
2956       TF_LITE_KERNEL_LOG(context,
2957                          "Unable to get node and registration for node %d.",
2958                          node_index);
2959       continue;  // Soft error (skip this node).
2960     }
2961 
2962     // Prepare to unpack FP16 tensors.
2963     if (registration->builtin_code == kTfLiteBuiltinDequantize &&
2964         node->inputs->size == 1 && node->outputs->size == 1) {
2965       const TfLiteTensor& input_tensor =
2966           context->tensors[node->inputs->data[0]];
2967       const TfLiteTensor& output_tensor =
2968           context->tensors[node->outputs->data[0]];
2969       if ((input_tensor.allocation_type == kTfLiteMmapRo ||
2970            quasi_static_tensors.count(node->inputs->data[0]) != 0) &&
2971           input_tensor.type == kTfLiteFloat16 &&
2972           output_tensor.type == kTfLiteFloat32) {
2973         static_unpack_nodes_.insert(node_index);
2974         quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
2975         quasi_static_tensors.insert(node->outputs->data[0]);
2976 
2977         if (input_tensor.allocation_type != kTfLiteMmapRo) {
2978           quasi_static_tensors_to_unpack.insert(node->inputs->data[0]);
2979         }
2980 
2981         // If dequantized input is sparse, so is its output
2982         if (static_sparse_weights_.count(node->inputs->data[0]) != 0) {
2983           static_sparse_weights_.insert(node->outputs->data[0]);
2984         }
2985 
2986         // Skip this node for now. If output of the node is consumed only by
2987         // delegated nodes, it will be added to nodes_to_delegate in the end.
2988         continue;
2989       }
2990     }
2991 
2992     // Prepare to unpack sparse tensors.
2993     // TODO(b/157729695): In the future, we also need to handle the case where a
2994     // sparse tensor is fed to a TFLite op directly, and no Densify() op is
2995     // inserted. For now this is not a problem because the Conv() op in tflite
2996     // can only consume dense tensors.
2997     if (registration->builtin_code == kTfLiteBuiltinDensify &&
2998         node->inputs->size == 1 && node->outputs->size == 1) {
2999       const TfLiteTensor& input_tensor =
3000           context->tensors[node->inputs->data[0]];
3001       const TfLiteTensor& output_tensor =
3002           context->tensors[node->outputs->data[0]];
3003       if (input_tensor.allocation_type == kTfLiteMmapRo &&
3004           input_tensor.sparsity != nullptr &&
3005           (input_tensor.type == kTfLiteFloat16 ||
3006            input_tensor.type == kTfLiteFloat32) &&
3007           output_tensor.type == input_tensor.type) {
3008         static_unpack_nodes_.insert(node_index);
3009         quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
3010         quasi_static_tensors.insert(node->outputs->data[0]);
3011         static_sparse_weights_.insert(node->outputs->data[0]);
3012 
3013         // Skip this node for now. If output of the node is consumed only by
3014         // delegated nodes, it will be added to nodes_to_delegate in the end.
3015         continue;
3016       }
3017     }
3018 
3019     if (Subgraph::VisitNode(/*subgraph=*/nullptr, context, registration, node,
3020                             node_index, quasi_static_tensors,
3021                             std::vector<uint32_t>()) != kTfLiteOk) {
3022       // If a non-delegated node consumes output of a node that unpacks static
3023       // data, that node shouldn't be delegated.
3024       for (int j = 0; j < node->inputs->size; j++) {
3025         const auto it =
3026             quasi_static_tensors_producers.find(node->inputs->data[j]);
3027         if (it != quasi_static_tensors_producers.end()) {
3028           static_unpack_nodes_.erase(it->second);
3029         }
3030       }
3031 
3032       // Non-delegatable node is not an error.
3033       continue;
3034     }
3035 
3036     for (int j = 0; j < node->inputs->size; j++) {
3037       if (quasi_static_tensors.count(node->inputs->data[j]) != 0) {
3038         quasi_static_tensors_to_unpack.insert(node->inputs->data[j]);
3039       }
3040     }
3041 
3042     nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
3043   }
3044 
3045   // Sort quasi-static tensors to be unpacked by the node index the produced
3046   // them. This ensures that in situations where quasi-static tensor is
3047   // produced from another quasi-static tensor, the tensors are unpacked in
3048   // the original execution plan order.
3049   std::vector<int> sorted_quasi_static_tensors_to_unpack(
3050       quasi_static_tensors_to_unpack.cbegin(),
3051       quasi_static_tensors_to_unpack.cend());
3052   std::sort(sorted_quasi_static_tensors_to_unpack.begin(),
3053             sorted_quasi_static_tensors_to_unpack.end(),
3054             [&quasi_static_tensors_producers](int t1, int t2) {
3055               return quasi_static_tensors_producers[t1] <
3056                      quasi_static_tensors_producers[t2];
3057             });
3058 
3059   // Unpack static data of all tensors
3060   for (int t : sorted_quasi_static_tensors_to_unpack) {
3061     const int producer_index = quasi_static_tensors_producers[t];
3062     // Check if TFLite nodes can be delegated to XNNPACK
3063     TfLiteNode* node = nullptr;
3064     TfLiteRegistration* registration = nullptr;
3065     if (context->GetNodeAndRegistration(context, producer_index, &node,
3066                                         &registration) != kTfLiteOk) {
3067       TF_LITE_KERNEL_LOG(context,
3068                          "Unable to get node and registration for node %d.",
3069                          producer_index);
3070       TfLiteIntArrayFree(nodes_to_delegate);
3071       return nullptr;  // Hard error.
3072     }
3073 
3074     if (node->inputs->size != 1) {
3075       TF_LITE_KERNEL_LOG(context, "unexpected number of inputs (%d) in node %d",
3076                          node->inputs->size, producer_index);
3077       TfLiteIntArrayFree(nodes_to_delegate);
3078       return nullptr;  // Hard error.
3079     }
3080 
3081     if (node->outputs->size != 1) {
3082       TF_LITE_KERNEL_LOG(context,
3083                          "unexpected number of outputs (%d) in node %d",
3084                          node->outputs->size, producer_index);
3085       TfLiteIntArrayFree(nodes_to_delegate);
3086       return nullptr;  // Hard error.
3087     }
3088 
3089     const TfLiteTensor& input_tensor = context->tensors[node->inputs->data[0]];
3090 
3091     // Consider the case when the input to unpacking node is quasi-static.
3092     const auto static_unpacked_input_it_ =
3093         static_unpacked_data_map_.find(node->inputs->data[0]);
3094     if (static_unpacked_input_it_ == static_unpacked_data_map_.end()) {
3095       if (input_tensor.allocation_type != kTfLiteMmapRo) {
3096         TF_LITE_KERNEL_LOG(
3097             context,
3098             "unexpected allocation type (%d) in tensor %d in node %d (%d)",
3099             input_tensor.allocation_type, node->inputs->data[0], producer_index,
3100             registration->builtin_code);
3101         TfLiteIntArrayFree(nodes_to_delegate);
3102         return nullptr;  // Hard error.
3103       }
3104     }
3105 
3106     const TfLiteTensor& output_tensor = context->tensors[t];
3107     size_t tensor_elements = output_tensor.bytes;
3108     switch (output_tensor.type) {
3109       case kTfLiteFloat32:
3110         tensor_elements /= sizeof(float);
3111         break;
3112       case kTfLiteFloat16:
3113         tensor_elements /= sizeof(uint16_t);
3114         break;
3115       default: {
3116         TF_LITE_KERNEL_LOG(context,
3117                            "unexpected datatype (%s) in tensor %d in node %d",
3118                            TfLiteTypeGetName(output_tensor.type),
3119                            node->outputs->data[0], producer_index);
3120         TfLiteIntArrayFree(nodes_to_delegate);
3121         return nullptr;  // Hard error.
3122       }
3123     }
3124 
3125     // Align to XNN_EXTRA_BYTES bytes
3126     while (static_unpacked_data_.size() % XNN_EXTRA_BYTES != 0) {
3127       static_unpacked_data_.push_back(0);
3128     }
3129     const size_t tensor_offset = static_unpacked_data_.size();
3130     static_unpacked_data_.resize(tensor_offset + context->tensors[t].bytes);
3131 
3132     char* unpacked_data = static_unpacked_data_.data() + tensor_offset;
3133     const char* packed_data =
3134         static_unpacked_input_it_ != static_unpacked_data_map_.end()
3135             ? static_unpacked_data_.data() + static_unpacked_input_it_->second
3136             : static_cast<const char*>(input_tensor.data.data);
3137     switch (registration->builtin_code) {
3138       case kTfLiteBuiltinDequantize: {
3139         if (input_tensor.type != kTfLiteFloat16) {
3140           TF_LITE_KERNEL_LOG(
3141               context, "unexpected tensor %d data type (%s) in node %d",
3142               node->inputs->data[0], TfLiteTypeGetName(input_tensor.type),
3143               producer_index);
3144           TfLiteIntArrayFree(nodes_to_delegate);
3145           return nullptr;  // Hard error.
3146         }
3147 
3148         if (input_tensor.sparsity != nullptr) {
3149           TF_LITE_KERNEL_LOG(context,
3150                              "unexpected FP16 sparse tensor %d in node %d",
3151                              node->inputs->data[0], producer_index);
3152           TfLiteIntArrayFree(nodes_to_delegate);
3153           return nullptr;  // Hard error.
3154         }
3155 
3156         // Actual data unpacking
3157         float* unpacked_fp32_data = reinterpret_cast<float*>(unpacked_data);
3158         const uint16_t* packed_fp16_data =
3159             reinterpret_cast<const uint16_t*>(packed_data);
3160         for (size_t i = 0; i < tensor_elements; i++) {
3161           unpacked_fp32_data[i] = fp16_ieee_to_fp32_value(packed_fp16_data[i]);
3162         }
3163         break;
3164       }
3165       case kTfLiteBuiltinDensify: {
3166         if (input_tensor.sparsity == nullptr) {
3167           TF_LITE_KERNEL_LOG(context, "unexpected dense tensor %d in node %d",
3168                              node->inputs->data[0], producer_index);
3169           TfLiteIntArrayFree(nodes_to_delegate);
3170           return nullptr;  // Hard error.
3171         }
3172 
3173         const int dims_count = output_tensor.dims->size;
3174         std::vector<int> vector_shape(dims_count);
3175         for (int i = 0; i < dims_count; i++) {
3176           vector_shape[i] = output_tensor.dims->data[i];
3177         }
3178 
3179         switch (input_tensor.type) {
3180           case kTfLiteFloat32: {
3181             const size_t dense_size = context->tensors[t].bytes / sizeof(float);
3182             float* unpacked_fp32_data = reinterpret_cast<float*>(unpacked_data);
3183             tflite::optimize::sparsity::FormatConverter<float> converter(
3184                 vector_shape, *input_tensor.sparsity);
3185             converter.SparseToDense(
3186                 static_cast<const float*>(input_tensor.data.data), dense_size,
3187                 unpacked_fp32_data, context);
3188             break;
3189           }
3190           case kTfLiteFloat16: {
3191             const size_t dense_size =
3192                 context->tensors[t].bytes / sizeof(Eigen::half);
3193             Eigen::half* unpacked_fp16_data =
3194                 reinterpret_cast<Eigen::half*>(unpacked_data);
3195             tflite::optimize::sparsity::FormatConverter<Eigen::half> converter(
3196                 vector_shape, *input_tensor.sparsity);
3197             converter.SparseToDense(
3198                 static_cast<const Eigen::half*>(input_tensor.data.data),
3199                 dense_size, unpacked_fp16_data, context);
3200             break;
3201           }
3202           default: {
3203             TF_LITE_KERNEL_LOG(
3204                 context, "unexpected tensor %d data type (%s) in node %d",
3205                 node->inputs->data[0], TfLiteTypeGetName(input_tensor.type),
3206                 producer_index);
3207             TfLiteIntArrayFree(nodes_to_delegate);
3208             return nullptr;  // Hard error.
3209           }
3210         }
3211         break;
3212       }
3213       default:
3214         TF_LITE_KERNEL_LOG(context, "unexpected op registration %d at node %d",
3215                            registration->builtin_code, producer_index);
3216         TfLiteIntArrayFree(nodes_to_delegate);
3217         return nullptr;  // Hard error.
3218     }
3219 
3220     static_unpacked_data_map_[t] = tensor_offset;
3221   }
3222 
3223   // Add nodes that unpack static data consumed by delegated nodes.
3224   // Note: this is done purely to avoid the overhead of running these nodes
3225   // again in TFLite interpreter which would allocate memory for their outputs.
3226   // We mark them as delegated, but the delegate would simply ignore these nodes
3227   // as the static weights are already unpacked.
3228   for (int node_index : static_unpack_nodes_) {
3229     nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
3230   }
3231   std::sort(&nodes_to_delegate->data[0],
3232             &nodes_to_delegate->data[nodes_to_delegate->size]);
3233 
3234 #ifdef XNNPACK_DELEGATE_TEST_MODE
3235   // In the test mode build (used by unit tests), XNNPACK delegate claims to
3236   // support all operators in the execution plan to disable fallback to the
3237   // default TensorFlow Lite kernels. Thus, if any of the ops in the model are
3238   // not supported by the delegate, they will cause a failure in
3239   // ::tflite::Interpreter::ModifyGraphWithDelegate, to be caught in the unit
3240   // tests.
3241   nodes_to_delegate->size = execution_plan->size;
3242   std::copy(&execution_plan->data[0],
3243             &execution_plan->data[execution_plan->size],
3244             &nodes_to_delegate->data[0]);
3245 #endif
3246 
3247   return nodes_to_delegate;
3248 }
3249 
SubgraphInit(TfLiteContext * context,const char * buffer,size_t length)3250 void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
3251   const TfLiteDelegateParams* params =
3252       reinterpret_cast<const TfLiteDelegateParams*>(buffer);
3253 
3254   return static_cast<void*>(Subgraph::Create(
3255       context, params,
3256       static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)));
3257 }
3258 
SubgraphPrepare(TfLiteContext * context,TfLiteNode * node)3259 TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
3260   if (node->user_data == nullptr) {
3261     return kTfLiteError;
3262   }
3263 
3264   return static_cast<Subgraph*>(node->user_data)->Prepare(context);
3265 }
3266 
SubgraphInvoke(TfLiteContext * context,TfLiteNode * node)3267 TfLiteStatus SubgraphInvoke(TfLiteContext* context, TfLiteNode* node) {
3268   if (node->user_data == nullptr) {
3269     return kTfLiteError;
3270   }
3271 
3272   return static_cast<Subgraph*>(node->user_data)->Invoke(context);
3273 }
3274 
SubgraphFree(TfLiteContext * context,void * buffer)3275 void SubgraphFree(TfLiteContext* context, void* buffer) {
3276   if (buffer != nullptr) {
3277     delete static_cast<Subgraph*>(buffer);
3278   }
3279 }
3280 
3281 const TfLiteRegistration kSubgraphRegistration = {
3282     /*.init=*/SubgraphInit,
3283     /*.free=*/SubgraphFree,
3284     /*.prepare=*/SubgraphPrepare,
3285     /*.invoke=*/SubgraphInvoke,
3286     /*.profiling_string=*/nullptr,
3287     /*.builtin_code=*/0,
3288     /*.custom_name=*/"TfLiteXNNPackDelegate",
3289     /*.version=*/2,
3290 };
3291 
DelegatePrepare(TfLiteContext * context,TfLiteDelegate * delegate)3292 TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
3293   TfLiteIntArray* ops_to_replace =
3294       static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)
3295           ->PrepareOpsToDelegate(context);
3296   if (ops_to_replace == nullptr) {
3297     return kTfLiteError;
3298   }
3299 
3300   const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
3301       context, kSubgraphRegistration, ops_to_replace, delegate);
3302   TfLiteIntArrayFree(ops_to_replace);
3303   return status;
3304 }
3305 
3306 }  // namespace
3307 }  // namespace xnnpack
3308 }  // namespace tflite
3309 
TfLiteXNNPackDelegateOptionsDefault()3310 TfLiteXNNPackDelegateOptions TfLiteXNNPackDelegateOptionsDefault() {
3311   TfLiteXNNPackDelegateOptions options = {0};
3312   return options;
3313 }
3314 
TfLiteXNNPackDelegateCreate(const TfLiteXNNPackDelegateOptions * options)3315 TfLiteDelegate* TfLiteXNNPackDelegateCreate(
3316     const TfLiteXNNPackDelegateOptions* options) {
3317   xnn_status status = xnn_initialize(/*allocator=*/nullptr);
3318   if (status != xnn_status_success) {
3319     return nullptr;
3320   }
3321 
3322   auto* xnnpack_delegate = new ::tflite::xnnpack::Delegate(options);
3323   return xnnpack_delegate ? xnnpack_delegate->tflite_delegate() : nullptr;
3324 }
3325 
TfLiteXNNPackDelegateDelete(TfLiteDelegate * delegate)3326 void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate) {
3327   if (delegate != nullptr) {
3328     delete static_cast<::tflite::xnnpack::Delegate*>(delegate->data_);
3329   }
3330 }
3331