1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "lang_id/common/embedding-network-params.h"
26 #include "lang_id/common/flatbuffers/embedding-network_generated.h"
27 #include "lang_id/common/lite_base/float16.h"
28 #include "lang_id/common/lite_base/logging.h"
29 #include "lang_id/common/lite_strings/stringpiece.h"
30 
31 namespace libtextclassifier3 {
32 namespace mobile {
33 
34 // EmbeddingNetworkParams implementation backed by a flatbuffer.
35 //
36 // For info on our flatbuffer schema, see embedding-network.fbs.
37 class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
38  public:
39   // Constructs an EmbeddingNetworkParamsFromFlatbuffer instance, using the
40   // flatbuffer from |bytes|.
41   //
42   // IMPORTANT #1: caller should make sure |bytes| are alive during the lifetime
43   // of this EmbeddingNetworkParamsFromFlatbuffer instance.  To avoid overhead,
44   // this constructor does not copy |bytes|.
45   //
46   // IMPORTANT #2: immediately after this constructor returns, we suggest you
47   // call is_valid() on the newly-constructed object and do not call any other
48   // method if the answer is negative (false).
49   explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);
50 
UpdateTaskContextParameters(mobile::TaskContext * task_context)51   bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
52     // This class does not provide access to the overall TaskContext.  It
53     // provides only parameters for the Neurosis neural network.
54     SAFTM_LOG(DFATAL) << "Not supported";
55     return false;
56   }
57 
is_valid()58   bool is_valid() const override { return valid_; }
59 
embeddings_size()60   int embeddings_size() const override { return SafeGetNumInputChunks(); }
61 
embeddings_num_rows(int i)62   int embeddings_num_rows(int i) const override {
63     const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
64     return SafeGetNumRows(matrix);
65   }
66 
embeddings_num_cols(int i)67   int embeddings_num_cols(int i) const override {
68     const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
69     return SafeGetNumCols(matrix);
70   }
71 
embeddings_weights(int i)72   const void *embeddings_weights(int i) const override {
73     const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
74     return SafeGetValuesOfMatrix(matrix);
75   }
76 
embeddings_quant_type(int i)77   QuantizationType embeddings_quant_type(int i) const override {
78     const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
79     return SafeGetQuantizationType(matrix);
80   }
81 
embeddings_quant_scales(int i)82   const float16 *embeddings_quant_scales(int i) const override {
83     const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
84     return SafeGetScales(matrix);
85   }
86 
hidden_size()87   int hidden_size() const override {
88     // -1 because last layer is always the softmax layer.
89     return std::max(SafeGetNumLayers() - 1, 0);
90   }
91 
hidden_num_rows(int i)92   int hidden_num_rows(int i) const override {
93     const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
94     return SafeGetNumRows(weights);
95   }
96 
hidden_num_cols(int i)97   int hidden_num_cols(int i) const override {
98     const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
99     return SafeGetNumCols(weights);
100   }
101 
hidden_weights_quant_type(int i)102   QuantizationType hidden_weights_quant_type(int i) const override {
103     const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
104     return SafeGetQuantizationType(weights);
105   }
106 
hidden_weights(int i)107   const void *hidden_weights(int i) const override {
108     const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
109     return SafeGetValuesOfMatrix(weights);
110   }
111 
hidden_bias_size()112   int hidden_bias_size() const override { return hidden_size(); }
113 
hidden_bias_num_rows(int i)114   int hidden_bias_num_rows(int i) const override {
115     const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
116     return SafeGetNumRows(bias);
117   }
118 
hidden_bias_num_cols(int i)119   int hidden_bias_num_cols(int i) const override {
120     const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
121     return SafeGetNumCols(bias);
122   }
123 
hidden_bias_weights(int i)124   const void *hidden_bias_weights(int i) const override {
125     const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
126     return SafeGetValues(bias);
127   }
128 
softmax_size()129   int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; }
130 
softmax_num_rows(int i)131   int softmax_num_rows(int i) const override {
132     const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
133     return SafeGetNumRows(weights);
134   }
135 
softmax_num_cols(int i)136   int softmax_num_cols(int i) const override {
137     const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
138     return SafeGetNumCols(weights);
139   }
140 
softmax_weights_quant_type(int i)141   QuantizationType softmax_weights_quant_type(int i) const override {
142     const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
143     return SafeGetQuantizationType(weights);
144   }
145 
softmax_weights(int i)146   const void *softmax_weights(int i) const override {
147     const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
148     return SafeGetValuesOfMatrix(weights);
149   }
150 
softmax_bias_size()151   int softmax_bias_size() const override { return softmax_size(); }
152 
softmax_bias_num_rows(int i)153   int softmax_bias_num_rows(int i) const override {
154     const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
155     return SafeGetNumRows(bias);
156   }
157 
softmax_bias_num_cols(int i)158   int softmax_bias_num_cols(int i) const override {
159     const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
160     return SafeGetNumCols(bias);
161   }
162 
softmax_bias_weights(int i)163   const void *softmax_bias_weights(int i) const override {
164     const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
165     return SafeGetValues(bias);
166   }
167 
embedding_num_features_size()168   int embedding_num_features_size() const override {
169     return SafeGetNumInputChunks();
170   }
171 
embedding_num_features(int i)172   int embedding_num_features(int i) const override {
173     if (!InRangeIndex(i, embedding_num_features_size(),
174                       "embedding num features")) {
175       return 0;
176     }
177     const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
178     if (input_chunk == nullptr) {
179       return 0;
180     }
181     return input_chunk->num_features();
182   }
183 
has_is_precomputed()184   bool has_is_precomputed() const override { return false; }
is_precomputed()185   bool is_precomputed() const override { return false; }
186 
187  private:
188   // Returns true if and only if index is in [0, limit).  info should be a
189   // pointer to a zero-terminated array of chars (ideally a literal string,
190   // e.g. "layer") indicating what the index refers to; info is used to make log
191   // messages more informative.
192   static bool InRangeIndex(int index, int limit, const char *info);
193 
194   // Returns network_->input_chunks()->size(), if all dereferences are safe
195   // (i.e., no nullptr); otherwise, returns 0.
196   int SafeGetNumInputChunks() const;
197 
198   // Returns network_->input_chunks()->Get(i), if all dereferences are safe
199   // (i.e., no nullptr) otherwise, returns nullptr.
200   const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;
201 
202   // Returns network_->input_chunks()->Get(i)->embedding(), if all dereferences
203   // are safe (i.e., no nullptr); otherwise, returns nullptr.
204   const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;
205 
206   // Returns network_->layers()->size(), if all dereferences are safe (i.e., no
207   // nullptr); otherwise, returns 0.
208   int SafeGetNumLayers() const;
209 
210   // Returns network_->layers()->Get(i), if all dereferences are safe
211   // (i.e., no nullptr); otherwise, returns nullptr.
212   const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;
213 
214   // Returns network_->layers()->Get(i)->weights(), if all dereferences are safe
215   // (i.e., no nullptr); otherwise, returns nullptr.
216   const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;
217 
218   // Returns network_->layers()->Get(i)->bias(), if all dereferences are safe
219   // (i.e., no nullptr); otherwise, returns nullptr.
220   const saft_fbs::Matrix *SafeGetLayerBias(int i) const;
221 
SafeGetNumRows(const saft_fbs::Matrix * matrix)222   static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
223     return (matrix == nullptr) ? 0 : matrix->rows();
224   }
225 
SafeGetNumCols(const saft_fbs::Matrix * matrix)226   static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
227     return (matrix == nullptr) ? 0 : matrix->cols();
228   }
229 
230   // Returns matrix->values()->data() if all dereferences are safe (i.e., no
231   // nullptr); otherwise, returns nullptr.
232   static const float *SafeGetValues(const saft_fbs::Matrix *matrix);
233 
234   // Returns matrix->quantized_values()->data() if all dereferences are safe
235   // (i.e., no nullptr); otherwise, returns nullptr.
236   static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);
237 
238   // Returns matrix->scales()->data() if all dereferences are safe (i.e., no
239   // nullptr); otherwise, returns nullptr.
240   static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);
241 
242   // Returns network_->layers()->Get(last_index) with last_index =
243   // SafeGetNumLayers() - 1, if all dereferences are safe (i.e., no nullptr) and
244   // there exists at least one layer; otherwise, returns nullptr.
245   const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;
246 
SafeGetSoftmaxWeights()247   const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
248     const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
249     return (layer == nullptr) ? nullptr : layer->weights();
250   }
251 
SafeGetSoftmaxBias()252   const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
253     const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
254     return (layer == nullptr) ? nullptr : layer->bias();
255   }
256 
257   // Returns the quantization type for |matrix|.  Returns NONE in case of
258   // problems (e.g., matrix is nullptr or unknown quantization type).
259   QuantizationType SafeGetQuantizationType(
260       const saft_fbs::Matrix *matrix) const;
261 
262   // Returns a pointer to the values (float, uint8, or float16, depending on
263   // quantization) from |matrix|, in row-major order.  Returns nullptr in case
264   // of a problem.
265   const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;
266 
267   // Performs some validity checks.  E.g., check that dimensions of the network
268   // layers match.  Also checks that all pointers we return are inside the
269   // |bytes| passed to the constructor, such that client that reads from those
270   // pointers will not run into troubles.
271   bool ValidityChecking(StringPiece bytes) const;
272 
273   // True if these params are valid.  May be false if the original proto was
274   // corrupted.  We prefer to set this to false to CHECK-failing.
275   bool valid_ = false;
276 
277   // EmbeddingNetwork flatbuffer from the bytes passed as parameter to the
278   // constructor; see constructor doc.
279   const saft_fbs::EmbeddingNetwork *network_ = nullptr;
280 };
281 
282 }  // namespace mobile
}  // namespace libtextclassifier3
284 
285 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
286