/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

// Creates a TF.Lite Op resolver in default configuration, with ops for
// Annotator and Actions models.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();

// Like above, but allows passing a function that can register additional
// ops.
std::unique_ptr<tflite::OpResolver> BuildOpResolver(
    const std::function<void(tflite::MutableOpResolver*)>& customize_fn);

std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    const tflite::Model*);
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    const flatbuffers::Vector<uint8_t>*);

// Executor for the text selection prediction and classification models.
class TfLiteModelExecutor {
 public:
  static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }
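
  // A minimal usage sketch, for illustration only: `model_buffer`, the tensor
  // indices, and the float input shape are hypothetical, and error checking
  // (e.g. on Invoke()) is omitted:
  //
  //   std::unique_ptr<TfLiteModelExecutor> executor =
  //       TfLiteModelExecutor::FromBuffer(model_buffer);
  //   std::unique_ptr<tflite::Interpreter> interpreter =
  //       executor->CreateInterpreter();
  //   interpreter->AllocateTensors();
  //   executor->SetInput<float>(/*input_index=*/0,
  //                             std::vector<float>{1.f, 2.f, 3.f},
  //                             interpreter.get());
  //   interpreter->Invoke();
  //   std::vector<float> scores =
  //       executor->Output<float>(/*output_index=*/0, interpreter.get());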
  // Creates an Interpreter for the model that serves as a scratch-pad for the
  // inference. The Interpreter is NOT thread-safe.
  std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;

  template <typename T>
  void SetInput(const int input_index, const TensorView<T>& input_data,
                tflite::Interpreter* interpreter) const {
    input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
                       input_data.size());
  }

  template <typename T>
  void SetInput(const int input_index, const std::vector<T>& input_data,
                tflite::Interpreter* interpreter) const {
    std::copy(input_data.begin(), input_data.end(),
              interpreter->typed_input_tensor<T>(input_index));
  }

  template <typename T>
  void SetInput(const int input_index, const T input_value,
                tflite::Interpreter* interpreter) const {
    TfLiteTensor* input_tensor =
        interpreter->tensor(interpreter->inputs()[input_index]);
    switch (input_tensor->type) {
      case kTfLiteFloat32:
        *tflite::GetTensorData<float>(input_tensor) = input_value;
        break;
      case kTfLiteInt32:
        *tflite::GetTensorData<int32_t>(input_tensor) = input_value;
        break;
      case kTfLiteUInt8:
        *tflite::GetTensorData<uint8_t>(input_tensor) = input_value;
        break;
      case kTfLiteInt64:
        *tflite::GetTensorData<int64_t>(input_tensor) = input_value;
        break;
      case kTfLiteBool:
        *tflite::GetTensorData<bool>(input_tensor) = input_value;
        break;
      case kTfLiteInt16:
        *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
        break;
      case kTfLiteInt8:
        *tflite::GetTensorData<int8_t>(input_tensor) = input_value;
        break;
      default:
        break;
    }
  }

  template <typename T>
  TensorView<T> OutputView(const int output_index,
                           const tflite::Interpreter* interpreter) const {
    const TfLiteTensor* output_tensor =
        interpreter->tensor(interpreter->outputs()[output_index]);
    return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
                         std::vector<int>(output_tensor->dims->data,
                                          output_tensor->dims->data +
                                              output_tensor->dims->size));
  }

  template <typename T>
  std::vector<T> Output(const int output_index,
                        const tflite::Interpreter* interpreter) const {
    TensorView<T> output_view = OutputView<T>(output_index, interpreter);
    return std::vector<T>(output_view.data(),
                          output_view.data() + output_view.size());
  }

 protected:
  explicit TfLiteModelExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model);
  TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,
                      std::unique_ptr<tflite::OpResolver> resolver);

  std::unique_ptr<const tflite::FlatBufferModel> model_;
  std::unique_ptr<tflite::OpResolver> resolver_;
};

template <>
void TfLiteModelExecutor::SetInput(const int input_index,
                                   const std::vector<std::string>& input_data,
                                   tflite::Interpreter* interpreter) const;

template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

template <>
std::vector<std::string> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_