1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains classes that can execute different models/parts of a model.
18 
19 #ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
20 #define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
21 
22 #include <memory>
23 
24 #include "tensor-view.h"
25 #include "types.h"
26 #include "util/base/logging.h"
27 #include "tensorflow/contrib/lite/interpreter.h"
28 #include "tensorflow/contrib/lite/kernels/register.h"
29 #include "tensorflow/contrib/lite/model.h"
30 
31 namespace libtextclassifier2 {
32 
namespace internal {
// Converts the given tflite::Model specification into a FlatBufferModel and
// stores it in *model. Returns false on failure.
// NOTE(review): implementation lives in the .cc file — exact failure
// conditions are not visible here.
bool FromModelSpec(const tflite::Model* model_spec,
                   std::unique_ptr<const tflite::FlatBufferModel>* model);
}  // namespace internal
37 
// A helper function that, given the indices of the feature and logits tensors
// and the feature values, computes the logits using the given interpreter.
TensorView<float> ComputeLogitsHelper(const int input_index_features,
                                      const int output_index_logits,
                                      const TensorView<float>& features,
                                      tflite::Interpreter* interpreter);
44 
45 // Executor for the text selection prediction and classification models.
46 class ModelExecutor {
47  public:
Instance(const flatbuffers::Vector<uint8_t> * model_spec_buffer)48   static std::unique_ptr<const ModelExecutor> Instance(
49       const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
50     const tflite::Model* model =
51         flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
52     flatbuffers::Verifier verifier(model_spec_buffer->data(),
53                                    model_spec_buffer->Length());
54     if (!model->Verify(verifier)) {
55       return nullptr;
56     }
57     return Instance(model);
58   }
59 
Instance(const tflite::Model * model_spec)60   static std::unique_ptr<const ModelExecutor> Instance(
61       const tflite::Model* model_spec) {
62     std::unique_ptr<const tflite::FlatBufferModel> model;
63     if (!internal::FromModelSpec(model_spec, &model)) {
64       return nullptr;
65     }
66     return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
67   }
68 
69   // Creates an Interpreter for the model that serves as a scratch-pad for the
70   // inference. The Interpreter is NOT thread-safe.
71   std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
72 
ComputeLogits(const TensorView<float> & features,tflite::Interpreter * interpreter)73   TensorView<float> ComputeLogits(const TensorView<float>& features,
74                                   tflite::Interpreter* interpreter) const {
75     return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits,
76                                features, interpreter);
77   }
78 
79  protected:
ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)80   explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
81       : model_(std::move(model)) {}
82 
83   static const int kInputIndexFeatures = 0;
84   static const int kOutputIndexLogits = 0;
85 
86   std::unique_ptr<const tflite::FlatBufferModel> model_;
87   tflite::ops::builtin::BuiltinOpResolver builtins_;
88 };
89 
90 // Executor for embedding sparse features into a dense vector.
91 class EmbeddingExecutor {
92  public:
~EmbeddingExecutor()93   virtual ~EmbeddingExecutor() {}
94 
95   // Embeds the sparse_features into a dense embedding and adds (+) it
96   // element-wise to the dest vector.
97   virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
98                             int dest_size) const = 0;
99 
100   // Returns true when the model is ready to be used, false otherwise.
IsReady()101   virtual bool IsReady() const { return true; }
102 };
103 
// EmbeddingExecutor backed by a TensorFlow Lite model that stores the
// embedding parameters (scales and quantized embedding values).
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates an executor from the given serialized model buffer.
  // NOTE(review): implementation is in the .cc file; presumably returns
  // nullptr on parse/verification failure — confirm there.
  static std::unique_ptr<TFLiteEmbeddingExecutor> Instance(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits);

  // See EmbeddingExecutor::AddEmbedding.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

 protected:
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model,
      int quantization_bits, int num_buckets, int bytes_per_embedding,
      int output_embedding_size, const TfLiteTensor* scales,
      const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter);

  std::unique_ptr<const tflite::FlatBufferModel> model_;

  // Embedding-table geometry, set by the constructor; -1 means "not set".
  int quantization_bits_;
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;
  // Non-owning views of model tensors; presumably they point into tensors
  // held alive by interpreter_ below — confirm against the .cc file.
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as a storage for the
  // model params), thus is still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;
};
134 
135 }  // namespace libtextclassifier2
136 
137 #endif  // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
138