1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 19 20 #include <string> 21 #include <vector> 22 23 #include "lang_id/common/embedding-feature-extractor.h" 24 #include "lang_id/common/fel/feature-extractor.h" 25 #include "lang_id/common/fel/task-context.h" 26 #include "lang_id/common/fel/workspace.h" 27 #include "lang_id/common/lite_base/attributes.h" 28 29 namespace libtextclassifier3 { 30 namespace mobile { 31 32 template <class EXTRACTOR, class OBJ, class... ARGS> 33 class EmbeddingFeatureInterface { 34 public: 35 // Constructs this EmbeddingFeatureInterface. 36 // 37 // |arg_prefix| is a string prefix for the TaskContext parameters, passed to 38 // |the underlying EmbeddingFeatureExtractor. EmbeddingFeatureInterface(const string & arg_prefix)39 explicit EmbeddingFeatureInterface(const string &arg_prefix) 40 : feature_extractor_(arg_prefix) {} 41 42 // Sets up feature extractors and flags for processing (inference). SetupForProcessing(TaskContext * context)43 SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) { 44 return feature_extractor_.Setup(context); 45 } 46 47 // Initializes feature extractor resources for processing (inference) 48 // including requesting a workspace for caching extracted features. InitForProcessing(TaskContext * context)49 SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) { 50 if (!feature_extractor_.Init(context)) return false; 51 feature_extractor_.RequestWorkspaces(&workspace_registry_); 52 return true; 53 } 54 55 // Preprocesses *obj using the internal workspace registry. Preprocess(WorkspaceSet * workspace,OBJ * obj)56 void Preprocess(WorkspaceSet *workspace, OBJ *obj) const { 57 workspace->Reset(workspace_registry_); 58 feature_extractor_.Preprocess(workspace, obj); 59 } 60 61 // Extract features from |obj|. On return, FeatureVector features[i] 62 // contains the features for the embedding space #i. 63 // 64 // This function uses the precomputed info from |workspace|. Usage pattern: 65 // 66 // EmbeddingFeatureInterface<...> feature_interface; 67 // ... 68 // OBJ obj; 69 // WorkspaceSet workspace; 70 // feature_interface.Preprocess(&workspace, &obj); 71 // 72 // // For the same obj, but with different args: 73 // std::vector<FeatureVector> features; 74 // feature_interface.GetFeatures(obj, args, workspace, &features); 75 // 76 // This pattern is useful (more efficient) if you can pre-compute some info 77 // for the entire |obj|, which is reused by the feature extraction performed 78 // for different args. If that is not the case, you can use the simpler 79 // version GetFeaturesNoCaching below. GetFeatures(const OBJ & obj,ARGS...args,const WorkspaceSet & workspace,std::vector<FeatureVector> * features)80 void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace, 81 std::vector<FeatureVector> *features) const { 82 feature_extractor_.ExtractFeatures(workspace, obj, args..., features); 83 } 84 85 // Simpler version of GetFeatures(), for cases when there is no opportunity to 86 // reuse computation between feature extractions for the same |obj|, but with 87 // different |args|. Returns the extracted features. For more info, see the 88 // doc for GetFeatures(). GetFeaturesNoCaching(OBJ * obj,ARGS...args)89 std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj, 90 ARGS... args) const { 91 // Technically, we still use a workspace, because 92 // feature_extractor_.ExtractFeatures requires one. But there is no real 93 // caching here, as we start from scratch for each call to ExtractFeatures. 94 WorkspaceSet workspace; 95 Preprocess(&workspace, obj); 96 std::vector<FeatureVector> features(NumEmbeddings()); 97 GetFeatures(*obj, args..., workspace, &features); 98 return features; 99 } 100 101 // Returns number of embedding spaces. NumEmbeddings()102 int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); } 103 104 private: 105 // Typed feature extractor for embeddings. 106 EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_; 107 108 // The registry of shared workspaces in the feature extractor. 109 WorkspaceRegistry workspace_registry_; 110 }; 111 112 } // namespace mobile 113 } // namespace nlp_saft 114 115 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 116