1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/core/subgraph.h"
17 
18 #include <stdarg.h>
19 #include <stddef.h>
20 
21 #include <algorithm>
22 #include <cstdint>
23 #include <cstdlib>
24 #include <cstring>
25 #include <iterator>
26 #include <memory>
27 #include <utility>
28 #include <vector>
29 
30 #include "tensorflow/lite/allocation.h"
31 #include "tensorflow/lite/arena_planner.h"
32 #include "tensorflow/lite/builtin_ops.h"
33 #include "tensorflow/lite/c/common.h"
34 #include "tensorflow/lite/context_util.h"
35 #include "tensorflow/lite/core/api/error_reporter.h"
36 #include "tensorflow/lite/core/api/profiler.h"
37 #include "tensorflow/lite/core/api/tensor_utils.h"
38 #include "tensorflow/lite/core/macros.h"
39 #include "tensorflow/lite/experimental/resource/resource_base.h"
40 #include "tensorflow/lite/graph_info.h"
41 #include "tensorflow/lite/memory_planner.h"
42 #include "tensorflow/lite/minimal_logging.h"
43 #include "tensorflow/lite/schema/schema_generated.h"
44 #include "tensorflow/lite/util.h"
45 
46 namespace tflite {
47 
48 namespace {
49 
50 struct TfLiteQuantizationDeleter {
operator ()tflite::__anon0a8c24fa0111::TfLiteQuantizationDeleter51   void operator()(TfLiteQuantization* q) {
52     if (q) TfLiteQuantizationFree(q);
53   }
54 };
55 
56 using ScopedTfLiteQuantization =
57     std::unique_ptr<TfLiteQuantization, TfLiteQuantizationDeleter>;
58 
59 struct TfLiteSparsityDeleter {
operator ()tflite::__anon0a8c24fa0111::TfLiteSparsityDeleter60   void operator()(TfLiteSparsity* s) {
61     if (s) TfLiteSparsityFree(s);
62   }
63 };
64 
65 using ScopedTfLiteSparsity =
66     std::unique_ptr<TfLiteSparsity, TfLiteSparsityDeleter>;
67 
ReportOpError(TfLiteContext * context,const TfLiteNode & node,const TfLiteRegistration & registration,int node_index,const char * message)68 TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node,
69                            const TfLiteRegistration& registration,
70                            int node_index, const char* message) {
71   context->ReportError(
72       context, "Node number %d (%s) %s.\n", node_index,
73       registration.custom_name
74           ? registration.custom_name
75           : EnumNameBuiltinOperator(
76                 static_cast<BuiltinOperator>(registration.builtin_code)),
77       message);
78   return kTfLiteError;
79 }
80 
// Stub which reports an error through the context and returns kTfLiteError.
// It stands in for context functions that are forbidden in the current
// context. This single variadic function is registered for several different
// function pointers to save compiled binary size. Please note the
// restrictions:
// * The type of the first parameter has to be `TfLiteContext*`.
// * All parameters must be trivially destructible. (E.g. No C++ class)
TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) {
  context->ReportError(context,
                       "The function is forbidden if not calling in delegate.");
  return kTfLiteError;
}
91 
// Sets `*func` to ForbiddenContextFunction, reinterpreted as `FunctionType`.
// `FunctionType` is the function-pointer type of the slot being overwritten;
// it must satisfy the restrictions listed on ForbiddenContextFunction (first
// parameter `TfLiteContext*`, all parameters trivially destructible).
template <typename FunctionType>
void SetForbiddenContextFunction(FunctionType* func) {
  *func = reinterpret_cast<FunctionType>(ForbiddenContextFunction);
}
97 
98 // Returns true if at least one tensor in the given list is kTfLiteDynamic.
99 template <typename TensorIntArray>
HasDynamicTensorImpl(const TfLiteContext & context,const TensorIntArray & int_array)100 bool HasDynamicTensorImpl(const TfLiteContext& context,
101                           const TensorIntArray& int_array) {
102   for (int i : int_array) {
103     if (i == kTfLiteOptionalTensor) continue;
104     const TfLiteTensor& tensor = context.tensors[i];
105     if (tensor.allocation_type == kTfLiteDynamic) {
106       return true;
107     }
108   }
109   return false;
110 }
111 
HasDynamicTensor(const TfLiteContext & context,const TfLiteIntArray * int_array)112 bool HasDynamicTensor(const TfLiteContext& context,
113                       const TfLiteIntArray* int_array) {
114   return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array});
115 }
116 
117 // Gets the legacy TfLiteQuantizationParams from the current TfLiteQuantization.
GetLegacyQuantization(const TfLiteQuantization & quantization)118 TfLiteQuantizationParams GetLegacyQuantization(
119     const TfLiteQuantization& quantization) {
120   TfLiteQuantizationParams legacy_quantization;
121   legacy_quantization.scale = 0;
122   legacy_quantization.zero_point = 0;
123 
124   // If the quantization type isn't affine, return the empty
125   // legacy_quantization.
126   if (quantization.type != kTfLiteAffineQuantization) {
127     return legacy_quantization;
128   }
129 
130   auto* affine_quantization =
131       static_cast<TfLiteAffineQuantization*>(quantization.params);
132   if (!affine_quantization || !affine_quantization->scale ||
133       !affine_quantization->zero_point ||
134       affine_quantization->scale->size != 1 ||
135       affine_quantization->zero_point->size != 1) {
136     return legacy_quantization;
137   }
138 
139   // We know its per-layer quantization now.
140   legacy_quantization.scale = affine_quantization->scale->data[0];
141   legacy_quantization.zero_point = affine_quantization->zero_point->data[0];
142   return legacy_quantization;
143 }
144 
145 static constexpr const char kUnknownCustomOpName[] = "UnknownCustomOp";
GetTFLiteOpName(const TfLiteRegistration & op_reg)146 const char* GetTFLiteOpName(const TfLiteRegistration& op_reg) {
147   if (op_reg.builtin_code == tflite::BuiltinOperator_CUSTOM) {
148     const char* const custom_name = op_reg.custom_name;
149     return custom_name ? custom_name : kUnknownCustomOpName;
150   }
151   if (op_reg.builtin_code == tflite::BuiltinOperator_DELEGATE &&
152       op_reg.custom_name) {
153     return op_reg.custom_name;
154   }
155   return tflite::EnumNamesBuiltinOperator()[op_reg.builtin_code];
156 }
157 
// Checks that `allocation` can back `tensor`:
//  * its data pointer is non-null,
//  * it provides at least `tensor->bytes` bytes,
//  * its data pointer is a multiple of kDefaultTensorAlignment.
// Each check goes through TF_LITE_ENSURE, so a failed check returns from this
// function early; kTfLiteOk is returned when all checks pass.
TfLiteStatus ValidateCustomAllocationForTensor(
    TfLiteContext* context, const TfLiteTensor* tensor,
    const TfLiteCustomAllocation& allocation) {
  TF_LITE_ENSURE(context, allocation.data != nullptr);
  TF_LITE_ENSURE(context, allocation.bytes >= tensor->bytes);
  // Ensure provided memory is aligned to what TFLite requires.
  const intptr_t data_ptr_value = reinterpret_cast<intptr_t>(allocation.data);
  TF_LITE_ENSURE(context, data_ptr_value % kDefaultTensorAlignment == 0);
  return kTfLiteOk;
}
168 
169 }  // namespace
170 
171 // A trivial implementation of GraphInfo around the Interpreter.
172 // NOTE: this interpreter info represents the subset of the
173 // graph that is executed according to execution plan. Thus,
174 // the indices are execution plan indices rather than raw node
175 // indices.
176 class InterpreterInfo : public GraphInfo {
177  public:
InterpreterInfo(Subgraph * subgraph)178   explicit InterpreterInfo(Subgraph* subgraph) : subgraph_(subgraph) {}
179 
num_tensors() const180   size_t num_tensors() const override { return subgraph_->tensors_size(); }
tensor(size_t index)181   TfLiteTensor* tensor(size_t index) override {
182     return subgraph_->tensor(index);
183   }
num_execution_nodes() const184   size_t num_execution_nodes() const override {
185     return subgraph_->execution_plan().size();
186   }
num_total_nodes() const187   size_t num_total_nodes() const override { return subgraph_->nodes_size(); }
node(size_t index) const188   const TfLiteNode& node(size_t index) const override {
189     int node_index = subgraph_->execution_plan()[index];
190     return subgraph_->nodes_and_registration()[node_index].first;
191   }
node_index(size_t index) const192   size_t node_index(size_t index) const override {
193     return subgraph_->execution_plan()[index];
194   }
inputs() const195   const std::vector<int>& inputs() const override {
196     return subgraph_->inputs();
197   }
outputs() const198   const std::vector<int>& outputs() const override {
199     return subgraph_->outputs();
200   }
variables() const201   const std::vector<int>& variables() const override {
202     return subgraph_->variables();
203   }
204 
205  public:
206   Subgraph* subgraph_;
207 };
208 
// Constructs a subgraph.
// The four pointer arguments are stored as given (external_contexts_,
// error_reporter_, subgraphs_, resources_). The constructor then wires up the
// embedded TfLiteContext so its callbacks reach this object, reserves storage
// for tensors and nodes, and finishes by calling SwitchToKernelContext().
Subgraph::Subgraph(ErrorReporter* error_reporter,
                   TfLiteExternalContext** external_contexts,
                   std::vector<std::unique_ptr<Subgraph>>* subgraphs,
                   resource::ResourceMap* resources)
    : external_contexts_(external_contexts),
      error_reporter_(error_reporter),
      next_execution_plan_index_to_prepare_(0),
      next_execution_plan_index_to_plan_allocation_(0),
      subgraphs_(subgraphs),
      resources_(resources) {
  // TODO(b/161272052): Consider a better TfLiteContext initialization pattern:
  // impl_ lets the static C callbacks recover this Subgraph from a
  // TfLiteContext*.
  context_.impl_ = static_cast<void*>(this);
  context_.ResizeTensor = ResizeTensor;
  context_.ReportError = ReportErrorC;
  context_.AddTensors = AddTensors;
  // No tensors exist yet.
  context_.tensors = nullptr;
  context_.tensors_size = 0;
  context_.allow_fp32_relax_to_fp16 = false;
  context_.recommended_num_threads = -1;
  context_.GetExternalContext = GetExternalContext;
  context_.SetExternalContext = SetExternalContext;
  context_.profiler = nullptr;
  context_.GetTensor = nullptr;
  context_.GetEvalTensor = nullptr;

  // Reserve some space for the tensors to avoid excessive resizing.
  tensors_.reserve(kTensorsReservedCapacity);
  // The node list is reserved with the same capacity constant.
  nodes_and_registration_.reserve(kTensorsReservedCapacity);
  // Invalid to call these except from TfLiteDelegate
  SwitchToKernelContext();
}
240 
~Subgraph()241 Subgraph::~Subgraph() {
242   for (int node_index = 0; node_index < nodes_and_registration_.size();
243        ++node_index) {
244     CleanupNode(node_index);
245   }
246 
247   for (size_t i = 0; i < context_.tensors_size; i++) {
248     TfLiteTensor* tensor = &context_.tensors[i];
249     if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
250         tensor->delegate->FreeBufferHandle != nullptr) {
251       tensor->delegate->FreeBufferHandle(&context_, tensor->delegate,
252                                          &tensor->buffer_handle);
253     }
254     TfLiteTensorFree(tensor);
255   }
256 }
257 
CleanupNode(int node_index)258 void Subgraph::CleanupNode(int node_index) {
259   TfLiteNode& node = nodes_and_registration_[node_index].first;
260   const TfLiteRegistration& registration =
261       nodes_and_registration_[node_index].second;
262   TfLiteIntArrayFree(node.inputs);
263   TfLiteIntArrayFree(node.outputs);
264   TfLiteIntArrayFree(node.temporaries);
265   TfLiteIntArrayFree(node.intermediates);
266   if (node.builtin_data) free(node.builtin_data);
267   OpFree(registration, node.user_data);
268   node.builtin_data = nullptr;
269 }
270 
ReplaceNodeSubsetsWithDelegateKernels(TfLiteContext * context,TfLiteRegistration registration,const TfLiteIntArray * nodes_to_replace,TfLiteDelegate * delegate)271 TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
272     TfLiteContext* context, TfLiteRegistration registration,
273     const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
274   return static_cast<Subgraph*>(context->impl_)
275       ->ReplaceNodeSubsetsWithDelegateKernels(registration, nodes_to_replace,
276                                               delegate);
277 }
278 
279 namespace {
280 
281 // Copy a std::vector<int> to an existing TfLiteIntArray.
282 // This is a low-level data manipulation function, and it's caller's
283 // responsibility to ensure TfLiteIntArray has enough size.
CopyVectorToTfLiteIntArray(const std::vector<int> & vec,TfLiteIntArray * arr)284 void CopyVectorToTfLiteIntArray(const std::vector<int>& vec,
285                                 TfLiteIntArray* arr) {
286   arr->size = vec.size();
287   memcpy(arr->data, vec.data(), sizeof(int) * arr->size);
288 }
289 
290 // This function allocates a continuous memory space that contains a
291 // TfLiteDelegateParams followed by a several TfLiteIntArray.
292 // When calling `free` at TfLiteDelegateParams*, all the allocated space
293 // will be freed together.
294 //
295 // +-----------------------------------+
296 // | TfLiteDelegateParams              |
297 // | TfLiteDelegate* delegate;         |
298 // | TfLiteIntArray* nodes_to_replace; |--\
299 // | TfLiteIntArray* input_tensors;    |--+--\
300 // | TfLiteIntArray* output_tensors;   |--+--+--\
301 // +-----------------------------------+  |  |  |
302 // | TfLiteIntArray (variable size)    |<-/  |  |
303 // +-----------------------------------+     |  |
304 // | TfLiteIntArray (variable size)    |<----/  |
305 // +-----------------------------------+        |
306 // | TfLiteIntArray (variable size)    |<-------/
307 // +-----------------------------------+
CreateDelegateParams(TfLiteDelegate * delegate,const NodeSubset & node_subset)308 TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
309                                            const NodeSubset& node_subset) {
310   // Step 1: Calculate the allocation size.
311   int allocation_size = sizeof(TfLiteDelegateParams);
312 
313   int nodes_to_replace_size =
314       TfLiteIntArrayGetSizeInBytes(node_subset.nodes.size());
315   allocation_size += nodes_to_replace_size;
316 
317   int input_tensors_size =
318       TfLiteIntArrayGetSizeInBytes(node_subset.input_tensors.size());
319   allocation_size += input_tensors_size;
320 
321   int output_tensors_size =
322       TfLiteIntArrayGetSizeInBytes(node_subset.output_tensors.size());
323   allocation_size += output_tensors_size;
324 
325   // Step 2: Allocate the memory.
326   // Use `char*` for conveniently step through the allocated space by bytes.
327   char* allocation = static_cast<char*>(malloc(allocation_size));
328 
329   // Step 3: Fill all data structures.
330   TfLiteDelegateParams* params =
331       reinterpret_cast<TfLiteDelegateParams*>(allocation);
332   params->delegate = delegate;
333   allocation += sizeof(TfLiteDelegateParams);
334 
335   params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation);
336   CopyVectorToTfLiteIntArray(node_subset.nodes, params->nodes_to_replace);
337   allocation += nodes_to_replace_size;
338 
339   params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
340   CopyVectorToTfLiteIntArray(node_subset.input_tensors, params->input_tensors);
341   allocation += input_tensors_size;
342 
343   params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
344   CopyVectorToTfLiteIntArray(node_subset.output_tensors,
345                              params->output_tensors);
346   allocation += output_tensors_size;
347 
348   return params;
349 }
350 
351 // Assumes that params is not nullptr.
PopulatePreviewDelegateParams(const NodeSubset & node_subset,TfLiteDelegateParams * params)352 void PopulatePreviewDelegateParams(const NodeSubset& node_subset,
353                                    TfLiteDelegateParams* params) {
354   // Since these params are used for previewing partitioning, params->delegate
355   // is not required.
356   params->delegate = nullptr;
357 
358   params->nodes_to_replace = TfLiteIntArrayCreate(node_subset.nodes.size());
359   CopyVectorToTfLiteIntArray(node_subset.nodes, params->nodes_to_replace);
360 
361   params->input_tensors =
362       TfLiteIntArrayCreate(node_subset.input_tensors.size());
363   CopyVectorToTfLiteIntArray(node_subset.input_tensors, params->input_tensors);
364 
365   params->output_tensors =
366       TfLiteIntArrayCreate(node_subset.output_tensors.size());
367   CopyVectorToTfLiteIntArray(node_subset.output_tensors,
368                              params->output_tensors);
369 }
370 
371 }  // namespace
372 
ReplaceNodeSubsetsWithDelegateKernels(TfLiteRegistration registration,const TfLiteIntArray * nodes_to_replace,TfLiteDelegate * delegate)373 TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
374     TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
375     TfLiteDelegate* delegate) {
376   // Ignore empty node replacement sets.
377   if (!nodes_to_replace->size) {
378     return kTfLiteOk;
379   }
380 
381   // Annotate the registration as DELEGATE op.
382   registration.builtin_code = BuiltinOperator_DELEGATE;
383 
384   // Analyze the graph to find all independent node_subsets that are either
385   // fully not-this-delegate or this-delegate computation.
386   InterpreterInfo info(this);
387   std::vector<NodeSubset> node_subsets;
388   PartitionGraphIntoIndependentNodeSubsets(&info, nodes_to_replace,
389                                            &node_subsets);
390 
391   TFLITE_LOG(
392       tflite::TFLITE_LOG_INFO,
393       "Replacing %d node(s) with delegate (%s) node, yielding %zu partitions.",
394       nodes_to_replace->size,
395       registration.custom_name ? registration.custom_name : "unknown",
396       node_subsets.size());
397 
398   execution_plan_.clear();
399 
400   for (auto& node_subset : node_subsets) {
401     // Subsets claimed by the delegate should have a "macro" op created, the
402     // other node_subsets (kTfNonPartition) just have their nodes added back to
403     // the execution plan.
404     switch (node_subset.type) {
405       case NodeSubset::kTfNonPartition:
406         for (auto it = node_subset.nodes.begin(); it != node_subset.nodes.end();
407              ++it) {
408           execution_plan_.push_back(*it);
409         }
410         break;
411       case NodeSubset::kTfPartition: {
412         int node_index;
413 
414         TfLiteDelegateParams* params =
415             CreateDelegateParams(delegate, node_subset);
416         TF_LITE_ENSURE_STATUS(AddNodeWithParameters(
417             node_subset.input_tensors, node_subset.output_tensors, {}, nullptr,
418             0, params, &registration, &node_index));
419 
420         // Initialize the output tensors's delegate-related fields.
421         for (int tensor_index : node_subset.output_tensors) {
422           TfLiteTensor* tensor = &tensors_[tensor_index];
423           TF_LITE_ENSURE(&context_, tensor->delegate == nullptr ||
424                                         tensor->delegate == delegate);
425           tensor->delegate = delegate;
426         }
427 
428         // Associate the node with the delegate.
429         TfLiteNode* node = &nodes_and_registration_[node_index].first;
430         node->delegate = delegate;
431       } break;
432       case NodeSubset::kTfUnexplored:
433         return kTfLiteError;
434         break;
435     }
436   }
437   return kTfLiteOk;
438 }
439 
GetExternalContext(TfLiteExternalContextType type)440 TfLiteExternalContext* Subgraph::GetExternalContext(
441     TfLiteExternalContextType type) {
442   if (static_cast<int>(type) >= 0 && type < kTfLiteMaxExternalContexts) {
443     return external_contexts_[type];
444   }
445   return nullptr;
446 }
447 
GetExternalContext(struct TfLiteContext * context,TfLiteExternalContextType type)448 TfLiteExternalContext* Subgraph::GetExternalContext(
449     struct TfLiteContext* context, TfLiteExternalContextType type) {
450   return static_cast<Subgraph*>(context->impl_)->GetExternalContext(type);
451 }
452 
SetExternalContext(TfLiteExternalContextType type,TfLiteExternalContext * ctx)453 void Subgraph::SetExternalContext(TfLiteExternalContextType type,
454                                   TfLiteExternalContext* ctx) {
455   if (static_cast<int>(type) >= 0 && type < kTfLiteMaxExternalContexts) {
456     external_contexts_[type] = ctx;
457   }
458 }
459 
SetExternalContext(struct TfLiteContext * context,TfLiteExternalContextType type,TfLiteExternalContext * ctx)460 void Subgraph::SetExternalContext(struct TfLiteContext* context,
461                                   TfLiteExternalContextType type,
462                                   TfLiteExternalContext* ctx) {
463   return static_cast<Subgraph*>(context->impl_)->SetExternalContext(type, ctx);
464 }
465 
466 // Gets an TfLiteIntArray* representing the execution plan. The interpreter owns
467 // this memory and it is only guaranteed to exist during the invocation of the
468 // delegate prepare.
GetExecutionPlan(TfLiteIntArray ** execution_plan)469 TfLiteStatus Subgraph::GetExecutionPlan(TfLiteIntArray** execution_plan) {
470   // TODO(aselle): Do not make a copy here
471   plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
472   *execution_plan = plan_cache_.get();
473   static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
474                 "TfLiteIntArray and execution_plan do not contain same type.");
475   std::memcpy(plan_cache_->data, execution_plan_.data(),
476               sizeof(plan_cache_->data[0]) * execution_plan_.size());
477   return kTfLiteOk;
478 }
479 
480 // WARNING: This is an experimental interface that is subject to change.
481 // Entry point for C node plugin API to get the execution plan
GetExecutionPlan(struct TfLiteContext * context,TfLiteIntArray ** execution_plan)482 TfLiteStatus Subgraph::GetExecutionPlan(struct TfLiteContext* context,
483                                         TfLiteIntArray** execution_plan) {
484   return static_cast<Subgraph*>(context->impl_)
485       ->GetExecutionPlan(execution_plan);
486 }
487 
FreeDelegatePartitioningData()488 void Subgraph::FreeDelegatePartitioningData() {
489   for (auto& params : partitioning_preview_cache_) {
490     TfLiteIntArrayFree(params.nodes_to_replace);
491     TfLiteIntArrayFree(params.input_tensors);
492     TfLiteIntArrayFree(params.output_tensors);
493   }
494   partitioning_preview_cache_.clear();
495 }
496 
PreviewDelegatePartitioning(const TfLiteIntArray * nodes_to_replace,TfLiteDelegateParams ** partition_params_array,int * num_partitions)497 TfLiteStatus Subgraph::PreviewDelegatePartitioning(
498     const TfLiteIntArray* nodes_to_replace,
499     TfLiteDelegateParams** partition_params_array, int* num_partitions) {
500   // Ensure partitioning cache is empty.
501   FreeDelegatePartitioningData();
502   // Defaults.
503   if (!partition_params_array || !num_partitions) return kTfLiteError;
504   *partition_params_array = nullptr;
505   *num_partitions = 0;
506   if (!nodes_to_replace->size) {
507     return kTfLiteOk;
508   }
509 
510   // Partition the execution plan into node subsets.
511   InterpreterInfo info(this);
512   std::vector<NodeSubset> node_subsets;
513   PartitionGraphIntoIndependentNodeSubsets(&info, nodes_to_replace,
514                                            &node_subsets);
515 
516   // Create one TfLiteDelegateParams per node-subset which would be delegated.
517   for (auto& node_subset : node_subsets) {
518     if (node_subset.type != NodeSubset::kTfPartition) {
519       continue;
520     }
521     partitioning_preview_cache_.emplace_back();
522     PopulatePreviewDelegateParams(node_subset,
523                                   &partitioning_preview_cache_.back());
524     ++*num_partitions;
525   }
526 
527   *partition_params_array = partitioning_preview_cache_.data();
528   return kTfLiteOk;
529 }
530 
PreviewDelegatePartitioning(struct TfLiteContext * context,const TfLiteIntArray * nodes_to_replace,TfLiteDelegateParams ** partition_params_array,int * num_partitions)531 TfLiteStatus Subgraph::PreviewDelegatePartitioning(
532     struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
533     TfLiteDelegateParams** partition_params_array, int* num_partitions) {
534   return static_cast<Subgraph*>(context->impl_)
535       ->PreviewDelegatePartitioning(nodes_to_replace, partition_params_array,
536                                     num_partitions);
537 }
538 
SetInputs(std::vector<int> inputs)539 TfLiteStatus Subgraph::SetInputs(std::vector<int> inputs) {
540   TF_LITE_ENSURE_OK(&context_,
541                     CheckTensorIndices("inputs", inputs.data(), inputs.size()));
542   inputs_ = std::move(inputs);
543   return kTfLiteOk;
544 }
545 
SetOutputs(std::vector<int> outputs)546 TfLiteStatus Subgraph::SetOutputs(std::vector<int> outputs) {
547   TF_LITE_ENSURE_OK(
548       &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
549   outputs_ = std::move(outputs);
550   return kTfLiteOk;
551 }
552 
SetVariables(std::vector<int> variables)553 TfLiteStatus Subgraph::SetVariables(std::vector<int> variables) {
554   TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(),
555                                                   variables.size()));
556   variables_ = std::move(variables);
557   return kTfLiteOk;
558 }
559 
SetCancellationFunction(void * data,bool (* check_cancelled_func)(void *))560 void Subgraph::SetCancellationFunction(void* data,
561                                        bool (*check_cancelled_func)(void*)) {
562   cancellation_data_ = data;
563   check_cancelled_func_ = check_cancelled_func;
564 }
565 
IsCancelled()566 bool Subgraph::IsCancelled() {
567   return (check_cancelled_func_ != nullptr) &&
568          (*check_cancelled_func_)(cancellation_data_);
569 }
570 
// Reserves capacity for `count` entries in the node/registration list.
void Subgraph::ReserveNodes(int count) {
  nodes_and_registration_.reserve(count);
}
574 
CheckTensorIndices(const char * label,const int * indices,int length)575 TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
576                                           int length) {
577   // Making sure kTfLiteOptionalTensor is not re-defined to something other than
578   // -1.
579   static_assert(kTfLiteOptionalTensor == -1,
580                 "kTfLiteOptionalTensor should be defined -1");
581 
582   for (int i = 0; i < length; i++) {
583     int index = indices[i];
584     // Continue if index == kTfLiteOptionalTensor before additional comparisons
585     // below, size_t(-1) is always >= context_tensors_size.
586     if (index == kTfLiteOptionalTensor) {
587       continue;
588     }
589     if (index < 0 || static_cast<size_t>(index) >= context_.tensors_size) {
590       ReportError(
591           "Invalid tensor index %d in %s. The subgraph has %d tensors\n", index,
592           label, context_.tensors_size);
593       consistent_ = false;
594       return kTfLiteError;
595     }
596   }
597   return kTfLiteOk;
598 }
599 
600 // We have two arrays and we need to check that elements from one array don't
601 // show up in the other. We could sort both arrays and then iterate with two
602 // pointers from start to finish always increasing the smaller one but since
603 // these arrays are usually short (<25 elements for inputs, usually <3 for
604 // outputs), this might be slower than the naive approach (if arrays have size n
605 // and m, with n >> m ~ O(1), first approach is O(nlogn) whereas the other is
606 // O(n)). Plus, sorting the input and output arrays might not be something we
607 // want as it destroys ordering of elements.
608 //
609 // If it turns out that this is an issue, we can switch to the other algorithm.
CheckInputAndOutputForOverlap(const int * input_indices,int num_inputs,const int * output_indices,int num_outputs)610 TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices,
611                                                      int num_inputs,
612                                                      const int* output_indices,
613                                                      int num_outputs) {
614   for (int i = 0; i < num_inputs; i++) {
615     for (int j = 0; j < num_outputs; j++) {
616       if (input_indices[i] == output_indices[j]) {
617         ReportError("Tensor %d is both input %d and output %d\n",
618                     input_indices[i], i, j);
619         consistent_ = false;
620         return kTfLiteError;
621       }
622     }
623   }
624   return kTfLiteOk;
625 }
626 
627 namespace {
628 // Multiply two sizes and return true if overflow occurred;
629 // This is based off tensorflow/overflow.h but is simpler as we already
630 // have unsigned numbers. It is also generalized to work where sizeof(size_t)
631 // is not 8.
MultiplyAndCheckOverflow(size_t a,size_t b,size_t * product)632 TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
633   // Multiplying a * b where a and b are size_t cannot result in overflow in a
634   // size_t accumulator if both numbers have no non-zero bits in their upper
635   // half.
636   constexpr size_t size_t_bits = 8 * sizeof(size_t);
637   constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2;
638   *product = a * b;
639   // If neither integers have non-zero bits past 32 bits can't overflow.
640   // Otherwise check using slow devision.
641   if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) {
642     if (a != 0 && *product / a != b) return kTfLiteError;
643   }
644   return kTfLiteOk;
645 }
646 }  // namespace
647 
BytesRequired(TfLiteType type,const int * dims,size_t dims_size,size_t * bytes)648 TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
649                                      size_t dims_size, size_t* bytes) {
650   TF_LITE_ENSURE(&context_, bytes != nullptr);
651   size_t count = 1;
652   for (int k = 0; k < dims_size; k++) {
653     size_t old_count = count;
654     TF_LITE_ENSURE_MSG(
655         &context_,
656         MultiplyAndCheckOverflow(old_count, dims[k], &count) == kTfLiteOk,
657         "BytesRequired number of elements overflowed.\n");
658   }
659   size_t type_size = 0;
660   TF_LITE_ENSURE_OK(&context_, GetSizeOfType(&context_, type, &type_size));
661   TF_LITE_ENSURE_MSG(
662       &context_, MultiplyAndCheckOverflow(type_size, count, bytes) == kTfLiteOk,
663       "BytesRequired number of bytes overflowed.\n");
664   return kTfLiteOk;
665 }
666 
// (Re)allocates the tensors of this subgraph and makes it invokable.
//
// Fails immediately if the model was flagged inconsistent. Otherwise it
// re-applies delegates, and -- unless a shortcut applies -- resets the
// preparation/allocation progress counters and the memory planner, runs
// PrepareOpsAndTensors(), marks the subgraph invokable and resets variable
// tensors.
TfLiteStatus Subgraph::AllocateTensors() {
  TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(), "AllocateTensors");
  if (!consistent_) {
    ReportError("AllocateTensors() called on inconsistent model.");
    return kTfLiteError;
  }

  // Restore delegation state if applicable.
  TF_LITE_ENSURE_STATUS(RedoAllDelegates());

  // Explicit (re)allocation is necessary if nodes have been changed or tensors
  // have been resized. For inputs marked as dynamic, we can't short-circuit the
  // allocation as the client may have done the resize manually.
  if (state_ != kStateUninvokable &&
      !HasDynamicTensorImpl(context_, inputs())) {
    if (memory_planner_ && !memory_planner_->HasNonPersistentMemory()) {
      // If the only change was the release of non-persistent memory via
      // ReleaseNonPersistentMemory(), just re-allocate it. For any other type
      // of memory-planning change (for eg, ResizeInputTensor), the state would
      // be kStateUninvokable.
      // NOTE(review): any result of AcquireNonPersistentMemory() is not
      // checked here -- confirm that is intended.
      memory_planner_->AcquireNonPersistentMemory();
    }
    return kTfLiteOk;
  }

  // Full re-preparation: start over from the first execution plan entry.
  next_execution_plan_index_to_prepare_ = 0;
  next_execution_plan_index_to_plan_allocation_ = 0;
  next_original_execution_plan_index_to_prepare_ = 0;
  if (memory_planner_) {
    TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
  }

  TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());

  state_ = kStateInvokable;

  // Reset the variable tensors to zero after (re)allocating the tensors.
  // Developers shouldn't rely on the side effect of this function to reset
  // variable tensors. They should call `ResetVariableTensors` directly
  // instead.
  // NOTE(review): the TfLiteStatus returned by ResetVariableTensors() is
  // discarded, so a failure there does not fail AllocateTensors() -- confirm
  // that is intended.
  ResetVariableTensors();

  return kTfLiteOk;
}
711 
712 // TODO(ycling): Support non-zero default values.
ResetVariableTensors()713 TfLiteStatus Subgraph::ResetVariableTensors() {
714   for (auto& tensor : tensors_) {
715     if (!tensor.is_variable) {
716       continue;
717     }
718 
719     if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
720       // If variable tensors allocation type is `kTfLiteArenaRwPersistent`, then
721       // they must be allocated after the initial `PrepareOpsAndTensors()` is
722       // called.
723       TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
724       tflite::ResetVariableTensor(&tensor);
725     } else {
726       // If variable tensors allocation type is not `kTfLiteArenaRwPersistent`,
727       // then it can only be `kTfLiteCustom` in which case, we do not reset it.
728       TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, kTfLiteCustom);
729     }
730   }
731   return kTfLiteOk;
732 }
733 
AddNodeWithParameters(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & intermediates,const char * init_data,size_t init_data_size,void * builtin_data,const TfLiteRegistration * registration,int * node_index)734 TfLiteStatus Subgraph::AddNodeWithParameters(
735     const std::vector<int>& inputs, const std::vector<int>& outputs,
736     const std::vector<int>& intermediates, const char* init_data,
737     size_t init_data_size, void* builtin_data,
738     const TfLiteRegistration* registration, int* node_index) {
739   std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
740                                                               free);
741   if (state_ == kStateInvokableAndImmutable) {
742     ReportError("AddNodeWithParameters is disallowed when graph is immutable.");
743     return kTfLiteError;
744   }
745   state_ = kStateUninvokable;
746 
747   TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
748                                                   inputs.size()));
749   TF_LITE_ENSURE_OK(
750       &context_,
751       CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
752 
753   // For builtin ops, inputs and outputs must not overlap. Custom ops must do
754   // this check by themselves if they don't support overlapping tensors. This
755   // distinction is to allow custom ops to just forward a tensor, reusing it as
756   // both input and output.
757   if (builtin_data != nullptr) {
758     TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap(
759                                      inputs.data(), inputs.size(),
760                                      outputs.data(), outputs.size()));
761   }
762 
763   int new_node_index = nodes_and_registration_.size();
764   if (node_index) *node_index = new_node_index;
765   nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
766   auto& node_and_reg = nodes_and_registration_.back();
767   TfLiteNode& node = node_and_reg.first;
768   if (node.inputs) TfLiteIntArrayFree(node.inputs);
769   if (node.outputs) TfLiteIntArrayFree(node.outputs);
770   if (node.intermediates) TfLiteIntArrayFree(node.intermediates);
771   if (node.temporaries) TfLiteIntArrayFree(node.temporaries);
772 
773   // NOTE, here we are not using move semantics yet, since our internal
774   // representation isn't std::vector, but in the future we would like to avoid
775   // copies, so we want the interface to take r-value references now.
776   node.inputs = ConvertVectorToTfLiteIntArray(inputs);
777   node.outputs = ConvertVectorToTfLiteIntArray(outputs);
778   node.intermediates = ConvertVectorToTfLiteIntArray(intermediates);
779   node.temporaries = TfLiteIntArrayCreate(0);
780   if (init_data) {
781     node.user_data = OpInit(*registration, init_data, init_data_size);
782   } else {
783     node.user_data = OpInit(
784         *registration, static_cast<const char*>(builtin_data_deleter.get()), 0);
785   }
786 
787   node.builtin_data = builtin_data_deleter.release();
788   // TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size`
789   // properly for nodes generated by ReplaceNodeSubsetsWithDelegateKernels.
790 
791   if (registration->builtin_code == BuiltinOperator_CUSTOM) {
792     // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
793     // `Operator` table is passed in.
794     node.custom_initial_data = init_data;
795     node.custom_initial_data_size = init_data_size;
796   } else {
797     node.custom_initial_data = nullptr;
798     node.custom_initial_data_size = 0;
799   }
800 
801   node.delegate = nullptr;
802   // Copying of registration is required to support unresolved custom ops.
803   node_and_reg.second = *registration;
804   execution_plan_.push_back(new_node_index);
805   return kTfLiteOk;
806 }
807 
ResizeInputTensor(int tensor_index,const std::vector<int> & dims)808 TfLiteStatus Subgraph::ResizeInputTensor(int tensor_index,
809                                          const std::vector<int>& dims) {
810   const bool delegates_applied = !pre_delegation_execution_plan_.empty();
811   const bool graph_is_immutable = state_ == kStateInvokableAndImmutable;
812   if (graph_is_immutable && !delegates_applied) {
813     ReportError("ResizeInputTensor is disallowed when graph is immutable.");
814     return kTfLiteError;
815   }
816 
817   // TODO(aselle): All bounds checks can be implemented as one-sided bounds
818   // checks by casting to unsigned for efficiency. Profile before doing this.
819   TF_LITE_ENSURE(&context_,
820                  tensor_index < context_.tensors_size && tensor_index >= 0);
821   TfLiteTensor* tensor = &context_.tensors[tensor_index];
822 
823   // Short-circuit the state change if the dimensions don't change, avoiding
824   // unnecessary (re)allocations.
825   //
826   // Note that it's required to check `tensor->data.raw != nullptr`. Otherwise
827   // the subgraph won't allocate memory for a dynamic tensor when its size
828   // is equal to the original tensor size.
829   if (tensor->data.raw != nullptr &&
830       EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) {
831     return kTfLiteOk;
832   }
833 
834   if (graph_is_immutable) {
835     // Undo delegation if it resulted in the graph being immutable.
836     TF_LITE_ENSURE_STATUS(UndoAllDelegates());
837   }
838   state_ = kStateUninvokable;
839   return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims));
840 }
841 
ResizeInputTensorStrict(int tensor_index,const std::vector<int> & dims)842 TfLiteStatus Subgraph::ResizeInputTensorStrict(int tensor_index,
843                                                const std::vector<int>& dims) {
844   TF_LITE_ENSURE(&context_,
845                  tensor_index < context_.tensors_size && tensor_index >= 0);
846   TfLiteTensor* tensor = &context_.tensors[tensor_index];
847 
848   // Ensure that only unknown dimensions can be resized.
849   TF_LITE_ENSURE_EQ(&context_, tensor->dims->size, dims.size());
850   for (size_t idx = 0; idx < dims.size(); idx++) {
851     // `dims_signature` is not defined when no unknown dimensions are present.
852     int dim_signature;
853     if (tensor->dims_signature && tensor->dims_signature->size) {
854       dim_signature = tensor->dims_signature->data[idx];
855     } else {
856       dim_signature = tensor->dims->data[idx];
857     }
858 
859     if (dim_signature != -1 && dim_signature != dims[idx]) {
860       ReportError(
861           "Attempting to resize dimension %d of tensor %d with value %d to %d. "
862           "ResizeInputTensorStrict only allows mutating unknown dimensions "
863           "identified by -1.",
864           idx, tensor_index, dim_signature, dims[idx]);
865       return kTfLiteError;
866     }
867   }
868 
869   return ResizeInputTensor(tensor_index, dims);
870 }
871 
ReleaseNonPersistentMemory()872 TfLiteStatus Subgraph::ReleaseNonPersistentMemory() {
873   if (memory_planner_) {
874     TF_LITE_ENSURE_STATUS(memory_planner_->ReleaseNonPersistentMemory());
875   }
876   return kTfLiteOk;
877 }
878 
OpPrepare(const TfLiteRegistration & op_reg,TfLiteNode * node)879 TfLiteStatus Subgraph::OpPrepare(const TfLiteRegistration& op_reg,
880                                  TfLiteNode* node) {
881   if (op_reg.prepare == nullptr) {
882     // Check if it's an unresolved custom op.
883     if (IsUnresolvedCustomOp(op_reg)) {
884       if (IsFlexOp(op_reg.custom_name)) {
885         ReportError(
886             "Regular TensorFlow ops are not supported by this interpreter. "
887             "Make sure you apply/link the Flex delegate before inference.");
888       } else {
889         ReportError("Encountered unresolved custom op: %s.",
890                     op_reg.custom_name ? op_reg.custom_name : "UnknownOp");
891       }
892       return kTfLiteError;
893     }
894     // Resolved ops can have a null Prepare function.
895     return kTfLiteOk;
896   }
897   return op_reg.prepare(&context_, node);
898 }
899 
PrepareOpsStartingAt(int first_execution_plan_index,const std::vector<int> & execution_plan,int * last_execution_plan_index_prepared)900 TfLiteStatus Subgraph::PrepareOpsStartingAt(
901     int first_execution_plan_index, const std::vector<int>& execution_plan,
902     int* last_execution_plan_index_prepared) {
903   if (first_execution_plan_index == 0) {
904     // Forwarding inputs without modification won't be not evaluated in the
905     // operators. So, it needs to look up the subgraph's output tensors at the
906     // beginning.
907     has_dynamic_tensors_ = HasDynamicTensorImpl(context_, outputs());
908   }
909   for (int execution_plan_index = first_execution_plan_index;
910        execution_plan_index < execution_plan.size(); execution_plan_index++) {
911     int node_index = execution_plan[execution_plan_index];
912     TfLiteNode& node = nodes_and_registration_[node_index].first;
913     const TfLiteRegistration& registration =
914         nodes_and_registration_[node_index].second;
915     EnsureTensorsVectorCapacity();
916     if (OpPrepare(registration, &node) != kTfLiteOk) {
917       return ReportOpError(&context_, node, registration, node_index,
918                            "failed to prepare");
919     }
920 
921     *last_execution_plan_index_prepared = execution_plan_index;
922 
923     // Discontinue if the node has dynamic outputs. Note that we don't
924     // stop for dynamic temporary tensors since they won't affect the
925     // sizes of other tensors in the graph.
926     if (HasDynamicTensor(context_, node.outputs)) {
927       has_dynamic_tensors_ = true;
928       return kTfLiteOk;
929     }
930   }
931   return kTfLiteOk;
932 }
933 
PrepareOpsAndTensors()934 TfLiteStatus Subgraph::PrepareOpsAndTensors() {
935   if (!memory_planner_) {
936     memory_planner_.reset(new ArenaPlanner(
937         &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
938         /*preserve_inputs=*/true, /*preserve_intermediates*/ false,
939         kDefaultTensorAlignment));
940     memory_planner_->PlanAllocations();
941   }
942 
943   // Prepare original execution plan if any applied delegate wants it.
944   // If any of the delegates is immutable, this won't be triggered
945   // post-delegation (since we undo/redo delegation). For all other cases, other
946   // delegates that do shape propagation themselves would still be able to.
947   bool prepare_original_plan = false;
948   if (!pre_delegation_execution_plan_.empty()) {
949     for (int i = 0; i < delegates_applied_.size(); ++i) {
950       if ((delegates_applied_[i]->flags &
951            kTfLiteDelegateFlagsRequirePropagatedShapes)) {
952         prepare_original_plan = true;
953         break;
954       }
955     }
956   }
957   if (prepare_original_plan) {
958     int last_original_exec_plan_index_prepared = 0;
959     TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt(
960         next_execution_plan_index_to_prepare_, pre_delegation_execution_plan_,
961         &last_original_exec_plan_index_prepared));
962     next_original_execution_plan_index_to_prepare_ =
963         last_original_exec_plan_index_prepared + 1;
964   }
965 
966   int last_exec_plan_index_prepared = 0;
967   TF_LITE_ENSURE_STATUS(
968       PrepareOpsStartingAt(next_execution_plan_index_to_prepare_,
969                            execution_plan_, &last_exec_plan_index_prepared));
970   next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1;
971 
972   // Execute arena allocations.
973   TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations(
974       next_execution_plan_index_to_plan_allocation_,
975       last_exec_plan_index_prepared));
976 
977   // Ensure custom allocations are still valid for applicable tensors.
978   // This causes some extra validations for cases with dynamic tensors, but the
979   // overhead should be minimal since the number of custom-allocated tensors
980   // will typically be low.
981   for (int i = 0; i < custom_allocations_.size(); ++i) {
982     auto index_and_alloc = custom_allocations_[i];
983     TfLiteTensor* tensor_at_index = tensor(index_and_alloc.first);
984     const auto& alloc = index_and_alloc.second;
985     TF_LITE_ENSURE(context(),
986                    tensor_at_index->allocation_type == kTfLiteCustom);
987     TF_LITE_ENSURE_STATUS(
988         ValidateCustomAllocationForTensor(context(), tensor_at_index, alloc));
989   }
990 
991   next_execution_plan_index_to_plan_allocation_ =
992       last_exec_plan_index_prepared + 1;
993 
994   return kTfLiteOk;
995 }
996 
// Runs every node of `execution_plan_` in plan order.
//
// Fails up front if the model is inconsistent, the subgraph is not in an
// invokable state, or the memory planner reports that non-persistent memory
// is not available. While iterating, ops are prepared lazily (see
// PrepareOpsAndTensors) whenever the loop reaches
// `next_execution_plan_index_to_prepare_`. Returns kTfLiteError on a missing
// input buffer, on client cancellation, or when an op fails to invoke.
TfLiteStatus Subgraph::Invoke() {
  if (!consistent_) {
    ReportError("Invoke called on model that is not consistent.");
    return kTfLiteError;
  }

  // `status` is never reassigned below; every failure returns directly.
  TfLiteStatus status = kTfLiteOk;
  if (state_ == kStateUninvokable) {
    ReportError("Invoke called on model that is not ready.");
    return kTfLiteError;
  } else if (memory_planner_ && !memory_planner_->HasNonPersistentMemory()) {
    ReportError("Non-persistent memory is not available.");
    return kTfLiteError;
  }

  // Invocations are always done in node order.
  // Note that calling Invoke repeatedly will cause the original memory plan to
  // be reused, unless either ResizeInputTensor() or AllocateTensors() has been
  // called.
  for (int execution_plan_index = 0;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    // Reached the first op that has not been prepared yet: prepare more ops
    // and allocate their tensors before running it.
    if (execution_plan_index == next_execution_plan_index_to_prepare_) {
      TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
      TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
                                    execution_plan_index);
    }
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;

    // The op name is only looked up when a profiler is attached.
    const char* op_name = nullptr;
    if (profiler_) op_name = GetTFLiteOpName(registration);
    TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler_.get(), op_name, node_index);

    // TODO(ycling): This is an extra loop through inputs to check if the data
    // need to be copied from Delegate buffer to raw memory, which is often not
    // needed. We may want to cache this in prepare to know if this needs to be
    // done for a node or not.
    for (int i = 0; i < node.inputs->size; ++i) {
      int tensor_index = node.inputs->data[i];
      if (tensor_index == kTfLiteOptionalTensor) {
        continue;
      }
      TfLiteTensor* tensor = &tensors_[tensor_index];
      // Input owned by a different delegate than this node and marked stale:
      // make its data readable before the op runs.
      if (tensor->delegate && tensor->delegate != node.delegate &&
          tensor->data_is_stale) {
        TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index));
      }
      if (tensor->data.raw == nullptr && tensor->bytes > 0) {
        if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) {
          // In general, having a tensor here with no buffer will be an error.
          // However, for the reshape operator, the second input tensor is only
          // used for the shape, not for the data. Thus, null buffer is ok.
          continue;
        } else {
          // In all other cases, we need to return an error as otherwise we will
          // trigger a null pointer dereference (likely).
          ReportError("Input tensor %d lacks data", tensor_index);
          return kTfLiteError;
        }
      }
    }

    // Give the client a chance to cancel before each op is run.
    if (check_cancelled_func_ != nullptr &&
        check_cancelled_func_(cancellation_data_)) {
      ReportError("Client requested cancel during Invoke()");
      return kTfLiteError;
    }

    EnsureTensorsVectorCapacity();
    // Cleared here so that the check after OpInvoke only sees resizes that
    // happened during this op (ResizeTensorImpl sets the flag).
    tensor_resized_since_op_invoke_ = false;
    if (OpInvoke(registration, &node) != kTfLiteOk) {
      return ReportOpError(&context_, node, registration, node_index,
                           "failed to invoke");
    }

    // Force execution prep for downstream ops if the latest op triggered the
    // resize of a dynamic tensor.
    if (tensor_resized_since_op_invoke_ &&
        HasDynamicTensor(context_, node.outputs)) {
      next_execution_plan_index_to_prepare_ = execution_plan_index + 1;

      // This happens when an intermediate dynamic tensor is resized.
      // We don't have to prepare all the ops, but we need to recompute
      // the allocation plan.
      if (next_execution_plan_index_to_plan_allocation_ >
          next_execution_plan_index_to_prepare_) {
        next_execution_plan_index_to_plan_allocation_ =
            next_execution_plan_index_to_prepare_;
        if (memory_planner_) {
          TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocationsAfter(
              next_execution_plan_index_to_plan_allocation_ - 1));
        }
      }
    }
  }

  return status;
}
1097 
ResizeTensor(TfLiteContext * context,TfLiteTensor * tensor,TfLiteIntArray * new_size)1098 TfLiteStatus Subgraph::ResizeTensor(TfLiteContext* context,
1099                                     TfLiteTensor* tensor,
1100                                     TfLiteIntArray* new_size) {
1101   // If the dimensions don't change, avoiding
1102   // unnecessary (re)allocations.
1103   //
1104   // Note that it's required to check `tensor->data.raw != nullptr`. Otherwise
1105   // the subgraph won't allocate memory for a dynamic tensor when its size
1106   // is equal to the original tensor size.
1107   if (tensor->data.raw != nullptr &&
1108       EqualArrayAndTfLiteIntArray(tensor->dims, new_size->size,
1109                                   new_size->data)) {
1110     // A number of clients assume |new_size| remains valid upon success, so
1111     // swap it in as the new (but logically identical) tensor dims.
1112     TfLiteIntArrayFree(tensor->dims);
1113     tensor->dims = new_size;
1114     return kTfLiteOk;
1115   }
1116 
1117   // Note here that context->impl_ is recovering the this pointer for an
1118   // instance of Interpreter to call into the member function ResizeTensorImpl
1119   // (this function is static).
1120   return static_cast<Subgraph*>(context->impl_)
1121       ->ResizeTensorImpl(tensor, new_size);
1122 }
1123 
// Forwards a printf-style `format` and its already-started `args` list to
// `error_reporter_`. The caller remains responsible for va_end().
void Subgraph::ReportErrorImpl(const char* format, va_list args) {
  error_reporter_->Report(format, args);
}
1127 
ReportErrorC(TfLiteContext * context,const char * format,...)1128 void Subgraph::ReportErrorC(TfLiteContext* context, const char* format, ...) {
1129   va_list args;
1130   va_start(args, format);
1131   auto* f = static_cast<Subgraph*>(context->impl_);
1132   // Note here that context->impl_ is recovering the this pointer for an
1133   // instance of Subgraph to call into the member function ReportErrorImpl
1134   // (this function is static).
1135   f->ReportErrorImpl(format, args);
1136   va_end(args);
1137 }
1138 
1139 // Entry point for C node plugin API to report an error.
ReportError(const char * format,...)1140 void Subgraph::ReportError(const char* format, ...) {
1141   va_list args;
1142   va_start(args, format);
1143   auto* f = static_cast<Subgraph*>(context_.impl_);
1144   // Note here that context->impl_ is recovering the this pointer for an
1145   // instance of Subgraph to call into the member function ReportErrorImpl
1146   // (this function is static).
1147   f->ReportErrorImpl(format, args);
1148   va_end(args);
1149 }
1150 
AddTensors(int tensors_to_add,int * first_new_tensor_index)1151 TfLiteStatus Subgraph::AddTensors(int tensors_to_add,
1152                                   int* first_new_tensor_index) {
1153   const size_t base_index = tensors_.size();
1154   if (first_new_tensor_index) *first_new_tensor_index = base_index;
1155   tensors_.resize(tensors_.size() + tensors_to_add);
1156   for (size_t i = base_index; i < tensors_.size(); i++) {
1157     memset(&tensors_[i], 0, sizeof(tensors_[i]));
1158     tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
1159   }
1160   context_.tensors = tensors_.data();
1161   context_.tensors_size = tensors_.size();
1162   return kTfLiteOk;
1163 }
1164 
AddTensors(TfLiteContext * context,int tensors_to_add,int * first_new_tensor_index)1165 TfLiteStatus Subgraph::AddTensors(TfLiteContext* context, int tensors_to_add,
1166                                   int* first_new_tensor_index) {
1167   // Note here that context->impl_ is recovering the this pointer for an
1168   // instance of Interpreter to call into the member function AddTensors
1169   // (this function is static).
1170   return static_cast<Subgraph*>(context->impl_)
1171       ->AddTensors(tensors_to_add, first_new_tensor_index);
1172 }
1173 
GetNodeAndRegistration(int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)1174 TfLiteStatus Subgraph::GetNodeAndRegistration(
1175     int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
1176   TF_LITE_ENSURE(&context_, node_index >= 0);
1177   auto nodes_size = nodes_and_registration_.size();
1178   TF_LITE_ENSURE(&context_, static_cast<size_t>(node_index) < nodes_size);
1179   TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
1180   auto& node_and_reg = nodes_and_registration_[node_index];
1181   *node = &node_and_reg.first;
1182   *registration = &node_and_reg.second;
1183   return kTfLiteOk;
1184 }
1185 
GetNodeAndRegistration(struct TfLiteContext * context,int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)1186 TfLiteStatus Subgraph::GetNodeAndRegistration(
1187     struct TfLiteContext* context, int node_index, TfLiteNode** node,
1188     TfLiteRegistration** registration) {
1189   return static_cast<Subgraph*>(context->impl_)
1190       ->GetNodeAndRegistration(node_index, node, registration);
1191 }
1192 
SetTensorParametersReadOnly(int tensor_index,TfLiteType type,const char * name,const size_t rank,const int * dims,TfLiteQuantization quantization,const char * buffer,size_t bytes,const Allocation * allocation,TfLiteSparsity * sparsity)1193 TfLiteStatus Subgraph::SetTensorParametersReadOnly(
1194     int tensor_index, TfLiteType type, const char* name, const size_t rank,
1195     const int* dims, TfLiteQuantization quantization, const char* buffer,
1196     size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity) {
1197   // Ensure quantization cleanup on failure.
1198   ScopedTfLiteQuantization scoped_quantization(&quantization);
1199   ScopedTfLiteSparsity scoped_sparsity(sparsity);
1200   if (state_ == kStateInvokableAndImmutable) {
1201     ReportError(
1202         "SetTensorParametersReadOnly is disallowed when graph is immutable.");
1203     return kTfLiteError;
1204   }
1205 
1206   TF_LITE_ENSURE(&context_,
1207                  tensor_index < context_.tensors_size && tensor_index >= 0);
1208 
1209   // For most tensors we know exactly how much memory is necessary so we can
1210   // ensure the buffer is large enough. However, we need to skip string tensors
1211   // and sparse tensors because their sizes change with the contents.
1212   // TODO(b/145615516): Extend BytesRequired to check sparse tensors.
1213   if (type != kTfLiteString && type != kTfLiteResource &&
1214       type != kTfLiteVariant && sparsity == nullptr) {
1215     size_t required_bytes;
1216     TF_LITE_ENSURE_OK(&context_,
1217                       BytesRequired(type, dims, rank, &required_bytes));
1218     TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
1219   }
1220 
1221   TfLiteTensor& tensor = context_.tensors[tensor_index];
1222   if (type == tensor.type &&
1223       EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
1224     // Fast path which does not invalidate the invokable property.
1225     TfLiteTensorDataFree(&tensor);
1226     TfLiteQuantizationFree(&tensor.quantization);
1227     tensor.data.raw = const_cast<char*>(buffer);
1228     if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
1229     tensor.params = GetLegacyQuantization(quantization);
1230     tensor.quantization = *scoped_quantization.release();
1231     tensor.sparsity = scoped_sparsity.release();
1232     tensor.allocation_type = kTfLiteMmapRo;
1233     tensor.allocation = allocation;
1234   } else {
1235     state_ = kStateUninvokable;
1236     TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
1237                       GetLegacyQuantization(quantization),
1238                       const_cast<char*>(buffer), bytes, kTfLiteMmapRo,
1239                       allocation, false, &tensor);
1240     // TODO(suharshs): Update TfLiteTensorReset to include the new quantization
1241     // if there are other required callers.
1242     tensor.quantization = *scoped_quantization.release();
1243     tensor.sparsity = scoped_sparsity.release();
1244   }
1245   return kTfLiteOk;
1246 }
1247 
1248 // Set description of inputs/outputs/data/fptrs for node `node_index`.
1249 // This variant assumes an external buffer has been allocated of size
1250 // bytes. The lifetime of buffer must be ensured to be greater or equal
1251 // to Interpreter.
SetTensorParametersReadWrite(int tensor_index,TfLiteType type,const char * name,const size_t rank,const int * dims,TfLiteQuantization quantization,bool is_variable,const size_t rank_dims_signature,const int * dims_signature)1252 TfLiteStatus Subgraph::SetTensorParametersReadWrite(
1253     int tensor_index, TfLiteType type, const char* name, const size_t rank,
1254     const int* dims, TfLiteQuantization quantization, bool is_variable,
1255     const size_t rank_dims_signature, const int* dims_signature) {
1256   // Ensure quantization cleanup on failure.
1257   ScopedTfLiteQuantization scoped_quantization(&quantization);
1258   if (state_ == kStateInvokableAndImmutable) {
1259     ReportError(
1260         "SetTensorParametersReadWrite is disallowed when graph is immutable.");
1261     return kTfLiteError;
1262   }
1263   TF_LITE_ENSURE(&context_,
1264                  tensor_index < context_.tensors_size && tensor_index >= 0);
1265   size_t required_bytes = 0;
1266   if (type != kTfLiteString && type != kTfLiteResource &&
1267       type != kTfLiteVariant) {
1268     // These types will be allocated in our arena so we need to record how
1269     // many bytes we will need based on the dimensions. String tensors are
1270     // allocated dynamically and we can't know ahead of time how much space
1271     // they will require.
1272     TF_LITE_ENSURE_OK(&context_,
1273                       BytesRequired(type, dims, rank, &required_bytes));
1274   }
1275 
1276   TfLiteAllocationType allocation_type = kTfLiteArenaRw;
1277   if (type == kTfLiteString || type == kTfLiteResource ||
1278       type == kTfLiteVariant) {
1279     if (is_variable) {
1280       // We don't have a real use case for string variable tensor.
1281       ReportError("String variable tensor isn't supported.");
1282       return kTfLiteError;
1283     }
1284     allocation_type = kTfLiteDynamic;
1285   } else if (is_variable) {
1286     allocation_type = kTfLiteArenaRwPersistent;
1287   }
1288 
1289   TfLiteTensor& tensor = context_.tensors[tensor_index];
1290   TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
1291                     GetLegacyQuantization(quantization),
1292                     /*buffer=*/nullptr, required_bytes, allocation_type,
1293                     nullptr, is_variable, &tensor);
1294   // TODO(suharshs): Update TfLiteTensorReset to include the new quantization
1295   // if there are other required callers.
1296   tensor.quantization = *scoped_quantization.release();
1297   tensor.dims_signature =
1298       ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature);
1299   return kTfLiteOk;
1300 }
1301 
// Replaces the execution plan with `new_plan`, a list of node indices.
//
// Every index is validated against `nodes_and_registration_` first; if any is
// out of range an error is returned and the current plan is left unchanged.
TfLiteStatus Subgraph::SetExecutionPlan(const std::vector<int>& new_plan) {
  for (int node_index : new_plan) {
    TF_LITE_ENSURE(&context_, node_index >= 0 &&
                                  node_index < nodes_and_registration_.size());
  }
  execution_plan_ = new_plan;
  return kTfLiteOk;
}
1310 
ResizeTensorImpl(TfLiteTensor * tensor,TfLiteIntArray * new_size)1311 TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
1312                                         TfLiteIntArray* new_size) {
1313   // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
1314   if (tensor->allocation_type == kTfLiteArenaRw ||
1315       tensor->allocation_type == kTfLiteDynamic ||
1316       tensor->allocation_type == kTfLiteArenaRwPersistent ||
1317       tensor->allocation_type == kTfLitePersistentRo ||
1318       tensor->allocation_type == kTfLiteCustom) {
1319     tensor_resized_since_op_invoke_ |=
1320         TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
1321     if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
1322         tensor->type != kTfLiteVariant) {
1323       size_t bytesRequired;
1324       TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
1325                                           new_size->size, &bytesRequired);
1326       if (status != kTfLiteOk) {
1327         TfLiteIntArrayFree(new_size);
1328         return kTfLiteError;
1329       }
1330 
1331       // Realloc space for heap-allocated tensors.
1332       TfLiteTensorRealloc(bytesRequired, tensor);
1333       tensor->bytes = bytesRequired;
1334     }
1335     if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
1336     tensor->dims = new_size;
1337 
1338     // Reset arena-allocated tensors; they will be allocated later.
1339     if (tensor->allocation_type == kTfLiteArenaRw ||
1340         tensor->allocation_type == kTfLiteArenaRwPersistent) {
1341       tensor->data.raw = nullptr;
1342     }
1343   } else {
1344     // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
1345     // of fixed size.
1346     TfLiteIntArrayFree(new_size);
1347     ReportError("Attempting to resize a fixed-size tensor.");
1348     return kTfLiteError;
1349   }
1350   return kTfLiteOk;
1351 }
1352 
SwitchToDelegateContext()1353 void Subgraph::SwitchToDelegateContext() {
1354   context_.GetNodeAndRegistration = GetNodeAndRegistration;
1355   context_.ReplaceNodeSubsetsWithDelegateKernels =
1356       ReplaceNodeSubsetsWithDelegateKernels;
1357   context_.GetExecutionPlan = GetExecutionPlan;
1358   context_.PreviewDelegatePartitioning = PreviewDelegatePartitioning;
1359 }
1360 
SwitchToKernelContext()1361 void Subgraph::SwitchToKernelContext() {
1362   context_.GetNodeAndRegistration = [](struct TfLiteContext* context,
1363                                        int node_index, TfLiteNode** node,
1364                                        TfLiteRegistration** registration) {
1365     return ForbiddenContextFunction(context);
1366   };
1367   context_.ReplaceNodeSubsetsWithDelegateKernels =
1368       [](TfLiteContext* context, TfLiteRegistration registration,
1369          const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
1370         return ForbiddenContextFunction(context);
1371       };
1372   context_.GetExecutionPlan = [](struct TfLiteContext* context,
1373                                  TfLiteIntArray**) {
1374     return ForbiddenContextFunction(context);
1375   };
1376   context_.PreviewDelegatePartitioning =
1377       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
1378          TfLiteDelegateParams** partition_params_array,
1379          int* num_partitions) { return ForbiddenContextFunction(context); };
1380   // Free any memory that might have been allocated by
1381   // PreviewDelegatePartitioning.
1382   FreeDelegatePartitioningData();
1383 }
1384 
UndoAllDelegates()1385 TfLiteStatus Subgraph::UndoAllDelegates() {
1386   // Return early if there is nothing to reset to.
1387   if (pre_delegation_execution_plan_.empty()) return kTfLiteOk;
1388 
1389   // First free all delegate nodes.
1390   for (int execution_plan_index = 0;
1391        execution_plan_index < execution_plan_.size(); ++execution_plan_index) {
1392     int node_index = execution_plan_[execution_plan_index];
1393     TfLiteNode& node = nodes_and_registration_[node_index].first;
1394     if (node.delegate == nullptr) {
1395       continue;
1396     }
1397     CleanupNode(node_index);
1398   }
1399 
1400   // Reset execution plan.
1401   execution_plan_ = pre_delegation_execution_plan_;
1402   pre_delegation_execution_plan_.clear();
1403 
1404   // Handling FP16 delegation (if applies).
1405   //
1406   // First pass through execution plan to remember mapping of FP16
1407   // dequantizations in the graph.
1408   // This is required because delegates that support FP16 could remap supported
1409   // nodes' inputs to point to their fp16 versions (if delegate supports fp16
1410   // acceleration). This remapping is performed in FP16GraphPartitionHelper in
1411   // delegates/utils. We need to undo this remapping to ensure CPU kernels work.
1412   std::vector<int> fp16_to_fp32(tensors_size(), -1);
1413   for (int execution_plan_index = 0;
1414        execution_plan_index < execution_plan_.size(); ++execution_plan_index) {
1415     int node_index = execution_plan_[execution_plan_index];
1416     auto& node_and_reg = nodes_and_registration_[node_index];
1417     const TfLiteNode& node = node_and_reg.first;
1418     const TfLiteRegistration& reg = node_and_reg.second;
1419     if (reg.builtin_code == kTfLiteBuiltinDequantize &&
1420         node.inputs->size == 1 && node.outputs->size == 1) {
1421       const int input_idx = node.inputs->data[0];
1422       if (tensors_[input_idx].type == kTfLiteFloat16) {
1423         fp16_to_fp32[input_idx] = node.outputs->data[0];
1424       }
1425     }
1426   }
1427   // Second pass through the execution plan to remap applicable nodes' fp16
1428   // inputs to their original fp32 versions. Note that if a CPU kernel does
1429   // support fp16, the model will not contain a DEQUANTIZE for its constant
1430   // input.
1431   for (int execution_plan_index = 0;
1432        execution_plan_index < execution_plan_.size(); ++execution_plan_index) {
1433     int node_index = execution_plan_[execution_plan_index];
1434     auto& node_and_reg = nodes_and_registration_[node_index];
1435     const TfLiteNode& node = node_and_reg.first;
1436     const TfLiteRegistration& reg = node_and_reg.second;
1437     if (reg.builtin_code == kTfLiteBuiltinDequantize) continue;
1438     for (int i = 0; i < node.inputs->size; ++i) {
1439       const int original_input_idx = node.inputs->data[i];
1440       if (tensors_[original_input_idx].type == kTfLiteFloat16) {
1441         node.inputs->data[i] = fp16_to_fp32[original_input_idx];
1442       }
1443     }
1444   }
1445 
1446   // Delegate nodes are appended to nodes_and_registration_. Therefore,
1447   // cleanup nodes_and_registration_ to only contain nodes from
1448   // pre_delegation_execution_plan_.
1449   int max_retained_node_index = 0;
1450   for (int execution_plan_index = 0;
1451        execution_plan_index < execution_plan_.size(); ++execution_plan_index) {
1452     max_retained_node_index = std::max(max_retained_node_index,
1453                                        execution_plan_[execution_plan_index]);
1454   }
1455   nodes_and_registration_.resize(max_retained_node_index + 1);
1456   // After undoing delegates, the graph is uninvokable, but mutable.
1457   state_ = kStateUninvokable;
1458 
1459   delegates_undone_ = true;
1460   return kTfLiteOk;
1461 }
1462 
RedoAllDelegates()1463 TfLiteStatus Subgraph::RedoAllDelegates() {
1464   if (!delegates_undone_) return kTfLiteOk;
1465 
1466   delegates_undone_ = false;
1467   std::vector<TfLiteDelegate*> delegates_to_apply;
1468   delegates_applied_.swap(delegates_to_apply);
1469   for (auto* delegate : delegates_to_apply) {
1470     TF_LITE_ENSURE_STATUS(ModifyGraphWithDelegate(delegate));
1471   }
1472   return kTfLiteOk;
1473 }
1474 
RemoveAllDelegates()1475 TfLiteStatus Subgraph::RemoveAllDelegates() {
1476   TF_LITE_ENSURE_STATUS(UndoAllDelegates());
1477   delegates_applied_.clear();
1478   delegates_undone_ = false;
1479   TF_LITE_ENSURE_STATUS(EnsureMemoryAllocations());
1480   return kTfLiteOk;
1481 }
1482 
HasDelegates()1483 bool Subgraph::HasDelegates() { return !delegates_applied_.empty(); }
1484 
EnsureTensorsVectorCapacity()1485 void Subgraph::EnsureTensorsVectorCapacity() {
1486   const size_t required_capacity = tensors_.size() + kTensorsCapacityHeadroom;
1487   if (required_capacity > tensors_.capacity()) {
1488     // Whenever it's required to increase the vector capacity, make it at
1489     // least twice bigger. The behavior is consistent with the default
1490     // behavior of GCC STL's `std::vector::resize()`. This avoids frequently
1491     // allocating and copying the underlying buffer.
1492     size_t reserved_capacity =
1493         std::max(required_capacity, tensors_.capacity() * 2);
1494     tensors_.reserve(reserved_capacity);
1495     context_.tensors = tensors_.data();
1496   }
1497 }
1498 
// Re-plans (when a memory planner exists) and re-allocates tensor memory, then
// verifies that the subgraph ended up in kStateInvokable.
//
// Returns kTfLiteOk when every step succeeds; otherwise the enclosing
// TF_LITE_ENSURE_* macro returns early with an error status.
TfLiteStatus Subgraph::EnsureMemoryAllocations() {
  if (memory_planner_) {
    // Mark the graph uninvokable before asking the planner to plan again.
    state_ = kStateUninvokable;
    TF_LITE_ENSURE_OK(&context_, memory_planner_->PlanAllocations());
  }
  TF_LITE_ENSURE_OK(&context_, AllocateTensors());
  // AllocateTensors() is expected to have left the graph invokable.
  TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable);
  return kTfLiteOk;
}
1508 
// Applies `delegate` to this subgraph by letting it rewrite the execution plan
// through delegate->Prepare().
//
// Outline:
//   1. Reject a null delegate (kTfLiteDelegateError).
//   2. Re-apply previously undone delegates, if any.
//   3. Reject the call if the graph is already immutable
//      (kTfLiteApplicationError).
//   4. For delegates without kTfLiteDelegateFlagsAllowDynamicTensors, prepare
//      all ops and refuse graphs with dynamic tensors
//      (kTfLiteApplicationError).
//   5. Remember the pre-delegation execution plan on the first delegate.
//   6. Expose the delegate-only context functions, call delegate->Prepare(),
//      and hide them again.
//   7. On a Prepare() or allocation failure, remove all delegates and return
//      kTfLiteDelegateError (or the status of RemoveAllDelegates() if that
//      itself fails).
//   8. Record the delegate in delegates_applied_.
TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(),
                                       "ModifyGraphWithDelegate");

  if (delegate == nullptr) {
    ReportError("Null delegate.");
    return kTfLiteDelegateError;
  }

  // Restore delegation state if applicable.
  TF_LITE_ENSURE_STATUS(RedoAllDelegates());

  if (state_ == kStateInvokableAndImmutable) {
    ReportError(
        "ModifyGraphWithDelegate is disallowed when graph is immutable.");
    return kTfLiteApplicationError;
  }

  // Delegates that do not allow dynamic tensors require every op to be
  // prepared up front so that dynamic tensors can be detected.
  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    int last_execution_plan_index_prepared;
    TF_LITE_ENSURE_OK(
        &context_, PrepareOpsStartingAt(0, execution_plan_,
                                        &last_execution_plan_index_prepared));
    if (has_dynamic_tensors_) {
      // Make sure that we are in a defined ready state before returning.
      // Plan and allocate tensors before returning.
      TF_LITE_ENSURE_OK(&context_, EnsureMemoryAllocations());
      ReportError(
          "Attempting to use a delegate that only supports static-sized "
          "tensors with a graph that has dynamic-sized tensors.");
      return kTfLiteApplicationError;
    }
  }

  // Captured before Prepare() runs; decides below whether allocations must be
  // flushed for delegates that allow dynamic tensors.
  const bool was_invokable_before_delegate = state_ == kStateInvokable;
  if (delegates_applied_.empty()) {
    // This is the first delegate being applied, so remember original execution
    // plan.
    // TODO(b/119623453): Restore execution plan to this state if delegate
    // application fails.
    pre_delegation_execution_plan_ = execution_plan_;
  }

  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  SwitchToDelegateContext();

  // Rollback helper: on a non-OK `status`, removes all delegates and reports
  // the failure as kTfLiteDelegateError; passes kTfLiteOk through unchanged.
  auto reset_delegation_if_not_ok = [this](TfLiteStatus status) {
    if (status != kTfLiteOk) {
      TF_LITE_ENSURE_STATUS(RemoveAllDelegates());
      ReportError(
          "Restored original execution plan after delegate application "
          "failure.");
      return kTfLiteDelegateError;
    }
    return kTfLiteOk;
  };

  TfLiteStatus status = delegate->Prepare(&context_, delegate);

  // Remove additional context info.
  SwitchToKernelContext();

  TF_LITE_ENSURE_STATUS(reset_delegation_if_not_ok(status));

  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    // Reset the state to force tensor/op reallocation.
    state_ = kStateUninvokable;
    TF_LITE_ENSURE_STATUS(
        reset_delegation_if_not_ok(EnsureMemoryAllocations()));
    // After using a delegate which doesn't support dynamic tensors, make the
    // entire graph immutable.
    state_ = kStateInvokableAndImmutable;
  } else if (was_invokable_before_delegate) {
    // If the graph was invokable prior to delegate application, flush
    // allocation now to leave it in a consistent state.
    TF_LITE_ENSURE_STATUS(
        reset_delegation_if_not_ok(EnsureMemoryAllocations()));
  }
  delegates_applied_.push_back(delegate);

  // NOTE(review): `status` should be kTfLiteOk here, assuming
  // TF_LITE_ENSURE_STATUS returns early on any non-OK value -- confirm.
  return status;
}
1592 
SetCustomAllocationForTensor(int tensor_index,const TfLiteCustomAllocation & allocation)1593 TfLiteStatus Subgraph::SetCustomAllocationForTensor(
1594     int tensor_index, const TfLiteCustomAllocation& allocation) {
1595   TfLiteTensor* tensor = &context_.tensors[tensor_index];
1596   TF_LITE_ENSURE(context(),
1597                  (tensor->allocation_type == kTfLiteArenaRw ||
1598                   tensor->allocation_type == kTfLiteArenaRwPersistent ||
1599                   tensor->allocation_type == kTfLiteCustom));
1600   TF_LITE_ENSURE_STATUS(
1601       ValidateCustomAllocationForTensor(context(), tensor, allocation));
1602 
1603   // If tensor already has a custom alloc, just reassign.
1604   const auto alloc_it = std::find_if(
1605       custom_allocations_.begin(), custom_allocations_.end(),
1606       [tensor_index](
1607           const std::pair<int, TfLiteCustomAllocation>& existing_alloc) {
1608         return existing_alloc.first == tensor_index;
1609       });
1610   if (alloc_it == custom_allocations_.end()) {
1611     custom_allocations_.emplace_back(tensor_index, allocation);
1612   } else {
1613     alloc_it->second = allocation;
1614   }
1615 
1616   tensor->allocation_type = kTfLiteCustom;
1617   tensor->data.data = allocation.data;
1618 
1619   return kTfLiteOk;
1620 }
1621 
SetName(const char * name)1622 void Subgraph::SetName(const char* name) {
1623   if (name) {
1624     name_ = name;
1625   } else {
1626     name_ = "";
1627   }
1628 }
1629 
GetName() const1630 const std::string& Subgraph::GetName() const { return name_; }
1631 
1632 }  // namespace tflite
1633