1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "lang_id/common/fel/feature-extractor.h"
25 #include "lang_id/common/fel/task-context.h"
26 #include "lang_id/common/fel/workspace.h"
27 #include "lang_id/common/lite_base/attributes.h"
28 
29 namespace libtextclassifier3 {
30 namespace mobile {
31 
32 // An EmbeddingFeatureExtractor manages the extraction of features for
33 // embedding-based models. It wraps a sequence of underlying classes of feature
34 // extractors, along with associated predicate maps. Each class of feature
35 // extractors is associated with a name, e.g., "words", "labels", "tags".
36 //
37 // The class is split between a generic abstract version,
38 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
39 // signature of the ExtractFeatures method) and a typed version.
40 //
41 // The predicate maps must be initialized before use: they can be loaded using
42 // Read() or updated via UpdateMapsForExample.
43 class GenericEmbeddingFeatureExtractor {
44  public:
45   // Constructs this GenericEmbeddingFeatureExtractor.
46   //
47   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
48   // avoid name clashes.  See GetParamName().
GenericEmbeddingFeatureExtractor(const std::string & arg_prefix)49   explicit GenericEmbeddingFeatureExtractor(const std::string &arg_prefix)
50       : arg_prefix_(arg_prefix) {}
51 
~GenericEmbeddingFeatureExtractor()52   virtual ~GenericEmbeddingFeatureExtractor() {}
53 
54   // Sets/inits up predicate maps and embedding space names that are common for
55   // all embedding based feature extractors.
56   //
57   // Returns true on success, false otherwise.
58   SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
59   SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
60 
61   // Requests workspace for the underlying feature extractors. This is
62   // implemented in the typed class.
63   virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
64 
65   // Returns number of embedding spaces.
NumEmbeddings()66   int NumEmbeddings() const { return embedding_dims_.size(); }
67 
embedding_fml()68   const std::vector<std::string> &embedding_fml() const {
69     return embedding_fml_;
70   }
71 
72   // Get parameter name by concatenating the prefix and the original name.
GetParamName(const std::string & param_name)73   std::string GetParamName(const std::string &param_name) const {
74     std::string full_name = arg_prefix_;
75     full_name.push_back('_');
76     full_name.append(param_name);
77     return full_name;
78   }
79 
80  private:
81   // Prefix for TaskContext parameters.
82   const std::string arg_prefix_;
83 
84   // Embedding space names for parameter sharing.
85   std::vector<std::string> embedding_names_;
86 
87   // FML strings for each feature extractor.
88   std::vector<std::string> embedding_fml_;
89 
90   // Size of each of the embedding spaces (maximum predicate id).
91   std::vector<int> embedding_sizes_;
92 
93   // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
94   std::vector<int> embedding_dims_;
95 };
96 
97 // Templated, object-specific implementation of the
98 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
99 // ARGS...> class that has the appropriate FeatureTraits() to ensure that
100 // locator type features work.
101 //
102 // Note: for backwards compatibility purposes, this always reads the FML spec
103 // from "<prefix>_features".
104 template <class EXTRACTOR, class OBJ, class... ARGS>
105 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
106  public:
107   // Constructs this EmbeddingFeatureExtractor.
108   //
109   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
110   // avoid name clashes.  See GetParamName().
EmbeddingFeatureExtractor(const std::string & arg_prefix)111   explicit EmbeddingFeatureExtractor(const std::string &arg_prefix)
112       : GenericEmbeddingFeatureExtractor(arg_prefix) {}
113 
114   // Sets up all predicate maps, feature extractors, and flags.
Setup(TaskContext * context)115   SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
116     if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
117       return false;
118     }
119     feature_extractors_.resize(embedding_fml().size());
120     for (int i = 0; i < embedding_fml().size(); ++i) {
121       feature_extractors_[i].reset(new EXTRACTOR());
122       if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
123       if (!feature_extractors_[i]->Setup(context)) return false;
124     }
125     return true;
126   }
127 
128   // Initializes resources needed by the feature extractors.
Init(TaskContext * context)129   SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
130     if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
131     for (auto &feature_extractor : feature_extractors_) {
132       if (!feature_extractor->Init(context)) return false;
133     }
134     return true;
135   }
136 
137   // Requests workspaces from the registry. Must be called after Init(), and
138   // before Preprocess().
RequestWorkspaces(WorkspaceRegistry * registry)139   void RequestWorkspaces(WorkspaceRegistry *registry) override {
140     for (auto &feature_extractor : feature_extractors_) {
141       feature_extractor->RequestWorkspaces(registry);
142     }
143   }
144 
145   // Must be called on the object one state for each sentence, before any
146   // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
Preprocess(WorkspaceSet * workspaces,OBJ * obj)147   void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
148     for (auto &feature_extractor : feature_extractors_) {
149       feature_extractor->Preprocess(workspaces, obj);
150     }
151   }
152 
153   // Extracts features using the extractors. Note that features must already
154   // be initialized to the correct number of feature extractors. No predicate
155   // mapping is applied.
ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & obj,ARGS...args,std::vector<FeatureVector> * features)156   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
157                        ARGS... args,
158                        std::vector<FeatureVector> *features) const {
159     // DCHECK(features != nullptr);
160     // DCHECK_EQ(features->size(), feature_extractors_.size());
161     for (int i = 0; i < feature_extractors_.size(); ++i) {
162       (*features)[i].clear();
163       feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
164                                               &(*features)[i]);
165     }
166   }
167 
168  private:
169   // Templated feature extractor class.
170   std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
171 };
172 
173 }  // namespace mobile
}  // namespace libtextclassifier3
175 
176 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
177