1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
18 
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
29 
30 namespace tflite {
31 namespace task {
32 namespace core {
33 
34 class BaseUntypedTaskApi {
35  public:
BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)36   explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)
37       : engine_{std::move(engine)} {}
38 
39   virtual ~BaseUntypedTaskApi() = default;
40 
GetTfLiteEngine()41   TfLiteEngine* GetTfLiteEngine() { return engine_.get(); }
GetTfLiteEngine()42   const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); }
43 
GetMetadataExtractor()44   const metadata::ModelMetadataExtractor* GetMetadataExtractor() const {
45     return engine_->metadata_extractor();
46   }
47 
48  protected:
49   std::unique_ptr<TfLiteEngine> engine_;
50 };
51 
52 template <class OutputType, class... InputTypes>
53 class BaseTaskApi : public BaseUntypedTaskApi {
54  public:
BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)55   explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)
56       : BaseUntypedTaskApi(std::move(engine)) {}
57   // BaseTaskApi is neither copyable nor movable.
58   BaseTaskApi(const BaseTaskApi&) = delete;
59   BaseTaskApi& operator=(const BaseTaskApi&) = delete;
60 
61   // Cancels the current running TFLite invocation on CPU.
62   //
63   // Usually called on a different thread than the one inference is running on.
64   // Calling Cancel() will cause the underlying TFLite interpreter to return an
65   // error, which will turn into a `CANCELLED` status and empty results. Calling
66   // Cancel() at the other time will not take any effect on the current or
67   // following invocation. It is perfectly fine to run inference again on the
68   // same instance after a cancelled invocation. If the TFLite inference is
69   // partially delegated on CPU, logs a warning message and only cancels the
70   // invocation running on CPU. Other invocation which depends on the output of
71   // the CPU invocation will not be executed.
Cancel()72   void Cancel() { engine_->Cancel(); }
73 
74  protected:
75   // Subclasses need to populate input_tensors from api_inputs.
76   virtual absl::Status Preprocess(
77       const std::vector<TfLiteTensor*>& input_tensors,
78       InputTypes... api_inputs) = 0;
79 
80   // Subclasses need to construct OutputType object from output_tensors.
81   // Original inputs are also provided as they may be needed.
82   virtual tflite::support::StatusOr<OutputType> Postprocess(
83       const std::vector<const TfLiteTensor*>& output_tensors,
84       InputTypes... api_inputs) = 0;
85 
86   // Returns (the addresses of) the model's inputs.
GetInputTensors()87   std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); }
88 
89   // Returns (the addresses of) the model's outputs.
GetOutputTensors()90   std::vector<const TfLiteTensor*> GetOutputTensors() {
91     return engine_->GetOutputs();
92   }
93 
94   // Performs inference using tflite::support::TfLiteInterpreterWrapper
95   // InvokeWithoutFallback().
Infer(InputTypes...args)96   tflite::support::StatusOr<OutputType> Infer(InputTypes... args) {
97     tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
98         engine_->interpreter_wrapper();
99     // Note: AllocateTensors() is already performed by the interpreter wrapper
100     // at InitInterpreter time (see TfLiteEngine).
101     RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
102     absl::Status status = interpreter_wrapper->InvokeWithoutFallback();
103     if (!status.ok()) {
104       return status.GetPayload(tflite::support::kTfLiteSupportPayload)
105                      .has_value()
106                  ? status
107                  : tflite::support::CreateStatusWithPayload(status.code(),
108                                                             status.message());
109     }
110     return Postprocess(GetOutputTensors(), args...);
111   }
112 
113   // Performs inference using tflite::support::TfLiteInterpreterWrapper
114   // InvokeWithFallback() to benefit from automatic fallback from delegation to
115   // CPU where applicable.
InferWithFallback(InputTypes...args)116   tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) {
117     tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
118         engine_->interpreter_wrapper();
119     // Note: AllocateTensors() is already performed by the interpreter wrapper
120     // at InitInterpreter time (see TfLiteEngine).
121     RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
122     auto set_inputs_nop =
123         [](tflite::task::core::TfLiteEngine::Interpreter* interpreter)
124         -> absl::Status {
125       // NOP since inputs are populated at Preprocess() time.
126       return absl::OkStatus();
127     };
128     absl::Status status =
129         interpreter_wrapper->InvokeWithFallback(set_inputs_nop);
130     if (!status.ok()) {
131       return status.GetPayload(tflite::support::kTfLiteSupportPayload)
132                      .has_value()
133                  ? status
134                  : tflite::support::CreateStatusWithPayload(status.code(),
135                                                             status.message());
136     }
137     return Postprocess(GetOutputTensors(), args...);
138   }
139 };
140 
141 }  // namespace core
142 }  // namespace task
143 }  // namespace tflite
144 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
145