1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Model parameter loading.
18 
19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
20 #define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
21 
22 #include "common/embedding-network.h"
23 #include "common/memory_image/embedding-network-params-from-image.h"
24 #include "smartselect/text-classification-model.pb.h"
25 
26 namespace libtextclassifier {
27 
28 class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage {
29  public:
EmbeddingParams(const void * start,uint64 num_bytes,int context_size)30   EmbeddingParams(const void* start, uint64 num_bytes, int context_size)
31       : EmbeddingNetworkParamsFromImage(start, num_bytes),
32         context_size_(context_size) {}
33 
embeddings_size()34   int embeddings_size() const override { return context_size_ * 2 + 1; }
35 
embedding_num_features_size()36   int embedding_num_features_size() const override {
37     return context_size_ * 2 + 1;
38   }
39 
embedding_num_features(int i)40   int embedding_num_features(int i) const override { return 1; }
41 
embeddings_num_rows(int i)42   int embeddings_num_rows(int i) const override {
43     return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
44   };
45 
embeddings_num_cols(int i)46   int embeddings_num_cols(int i) const override {
47     return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
48   };
49 
embeddings_weights(int i)50   const void* embeddings_weights(int i) const override {
51     return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
52   };
53 
embeddings_quant_type(int i)54   nlp_core::QuantizationType embeddings_quant_type(int i) const override {
55     return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
56   }
57 
embeddings_quant_scales(int i)58   const nlp_core::float16* embeddings_quant_scales(int i) const override {
59     return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
60   }
61 
62  private:
63   int context_size_;
64 };
65 
66 // Loads and holds the parameters of the inference network.
67 //
68 // This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
69 // because we only have one embedding matrix for all positions of context,
70 // whereas the original class would have a separate one for each.
71 class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
72  public:
GetFeatureProcessorOptions()73   const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
74     return feature_processor_options_;
75   }
76 
GetSelectionModelOptions()77   const SelectionModelOptions& GetSelectionModelOptions() const {
78     return selection_options_;
79   }
80 
GetSharingModelOptions()81   const SharingModelOptions& GetSharingModelOptions() const {
82     return sharing_options_;
83   }
84 
GetEmbeddingParams()85   std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const {
86     return embedding_params_;
87   }
88 
89  protected:
embeddings_size()90   int embeddings_size() const override {
91     return embedding_params_->embeddings_size();
92   }
93 
embedding_num_features_size()94   int embedding_num_features_size() const override {
95     return embedding_params_->embedding_num_features_size();
96   }
97 
embedding_num_features(int i)98   int embedding_num_features(int i) const override {
99     return embedding_params_->embedding_num_features(i);
100   }
101 
embeddings_num_rows(int i)102   int embeddings_num_rows(int i) const override {
103     return embedding_params_->embeddings_num_rows(i);
104   };
105 
embeddings_num_cols(int i)106   int embeddings_num_cols(int i) const override {
107     return embedding_params_->embeddings_num_cols(i);
108   };
109 
embeddings_weights(int i)110   const void* embeddings_weights(int i) const override {
111     return embedding_params_->embeddings_weights(i);
112   };
113 
embeddings_quant_type(int i)114   nlp_core::QuantizationType embeddings_quant_type(int i) const override {
115     return embedding_params_->embeddings_quant_type(i);
116   }
117 
embeddings_quant_scales(int i)118   const nlp_core::float16* embeddings_quant_scales(int i) const override {
119     return embedding_params_->embeddings_quant_scales(i);
120   }
121 
122  private:
123   friend ModelParams* ModelParamsBuilder(
124       const void* start, uint64 num_bytes,
125       std::shared_ptr<EmbeddingParams> external_embedding_params);
126 
ModelParams(const void * start,uint64 num_bytes,std::shared_ptr<EmbeddingParams> embedding_params,const SelectionModelOptions & selection_options,const SharingModelOptions & sharing_options,const FeatureProcessorOptions & feature_processor_options)127   ModelParams(const void* start, uint64 num_bytes,
128               std::shared_ptr<EmbeddingParams> embedding_params,
129               const SelectionModelOptions& selection_options,
130               const SharingModelOptions& sharing_options,
131               const FeatureProcessorOptions& feature_processor_options)
132       : EmbeddingNetworkParamsFromImage(start, num_bytes),
133         selection_options_(selection_options),
134         sharing_options_(sharing_options),
135         feature_processor_options_(feature_processor_options),
136         context_size_(feature_processor_options_.context_size()),
137         embedding_params_(std::move(embedding_params)) {}
138 
139   SelectionModelOptions selection_options_;
140   SharingModelOptions sharing_options_;
141   FeatureProcessorOptions feature_processor_options_;
142   int context_size_;
143   std::shared_ptr<EmbeddingParams> embedding_params_;
144 };
145 
146 ModelParams* ModelParamsBuilder(
147     const void* start, uint64 num_bytes,
148     std::shared_ptr<EmbeddingParams> external_embedding_params);
149 
150 }  // namespace libtextclassifier
151 
152 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
153