/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_

#include <memory>

#include "tensor-view.h"
#include "types.h"
#include "util/base/logging.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

namespace libtextclassifier2 {

namespace internal {
// Wraps the given flatbuffer model spec in a tflite::FlatBufferModel and
// stores it in |model|. Returns false on failure. (Definition lives in the
// corresponding .cc file, outside this view.)
bool FromModelSpec(const tflite::Model* model_spec,
                   std::unique_ptr<const tflite::FlatBufferModel>* model);
}  // namespace internal

// A helper function that given indices of feature and logits tensor, feature
// values computes the logits using given interpreter.
// NOTE(review): declaration only — the expected tensor shapes are not visible
// here; see the definition for the exact contract.
TensorView<float> ComputeLogitsHelper(const int input_index_features,
                                      const int output_index_logits,
                                      const TensorView<float>& features,
                                      tflite::Interpreter* interpreter);

// Executor for the text selection prediction and classification models.
46 class ModelExecutor { 47 public: Instance(const flatbuffers::Vector<uint8_t> * model_spec_buffer)48 static std::unique_ptr<const ModelExecutor> Instance( 49 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 50 const tflite::Model* model = 51 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); 52 flatbuffers::Verifier verifier(model_spec_buffer->data(), 53 model_spec_buffer->Length()); 54 if (!model->Verify(verifier)) { 55 return nullptr; 56 } 57 return Instance(model); 58 } 59 Instance(const tflite::Model * model_spec)60 static std::unique_ptr<const ModelExecutor> Instance( 61 const tflite::Model* model_spec) { 62 std::unique_ptr<const tflite::FlatBufferModel> model; 63 if (!internal::FromModelSpec(model_spec, &model)) { 64 return nullptr; 65 } 66 return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); 67 } 68 69 // Creates an Interpreter for the model that serves as a scratch-pad for the 70 // inference. The Interpreter is NOT thread-safe. 71 std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; 72 ComputeLogits(const TensorView<float> & features,tflite::Interpreter * interpreter)73 TensorView<float> ComputeLogits(const TensorView<float>& features, 74 tflite::Interpreter* interpreter) const { 75 return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits, 76 features, interpreter); 77 } 78 79 protected: ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)80 explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) 81 : model_(std::move(model)) {} 82 83 static const int kInputIndexFeatures = 0; 84 static const int kOutputIndexLogits = 0; 85 86 std::unique_ptr<const tflite::FlatBufferModel> model_; 87 tflite::ops::builtin::BuiltinOpResolver builtins_; 88 }; 89 90 // Executor for embedding sparse features into a dense vector. 
// Interface for turning a set of sparse feature ids into a dense embedding.
class EmbeddingExecutor {
 public:
  virtual ~EmbeddingExecutor() {}

  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector. |dest| must point to at least
  // |dest_size| floats. Returns false on failure.
  virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                            int dest_size) const = 0;

  // Returns true when the model is ready to be used, false otherwise.
  virtual bool IsReady() const { return true; }
};

// EmbeddingExecutor backed by a TFLite model that stores quantized embedding
// rows; rows are dequantized on lookup (see AddEmbedding's definition).
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates an executor from a flatbuffer-serialized embedding model.
  // Returns nullptr on failure (definition is outside this view).
  static std::unique_ptr<TFLiteEmbeddingExecutor> Instance(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits);

  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

 protected:
  // Takes ownership of the model and interpreter; |scales| and |embeddings|
  // are non-owning pointers into tensors held by |interpreter|.
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model,
      int quantization_bits, int num_buckets, int bytes_per_embedding,
      int output_embedding_size, const TfLiteTensor* scales,
      const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter);

  std::unique_ptr<const tflite::FlatBufferModel> model_;

  // Number of bits per quantized embedding value.
  int quantization_bits_;
  // -1 sentinels mark "not yet initialized"; the constructor (defined
  // elsewhere) presumably fills these in — confirm against the .cc file.
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;
  // Non-owning views into interpreter-held tensors; valid only while
  // |interpreter_| is alive.
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as a storage for the
  // model params), thus is still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;
};

}  // namespace libtextclassifier2

#endif  // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_