1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
18 #define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
19 
20 #include <string>
21 #include <vector>
22 
23 #include "common/embedding-feature-extractor.h"
24 #include "common/feature-extractor.h"
25 #include "common/task-context.h"
26 #include "common/workspace.h"
27 #include "lang_id/light-sentence-features.h"
28 #include "lang_id/light-sentence.h"
29 #include "util/base/macros.h"
30 
31 namespace libtextclassifier {
32 namespace nlp_core {
33 namespace lang_id {
34 
35 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
36 class LangIdEmbeddingFeatureExtractor
37     : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
38  public:
LangIdEmbeddingFeatureExtractor()39   LangIdEmbeddingFeatureExtractor() {}
ArgPrefix()40   const std::string ArgPrefix() const override { return "language_identifier"; }
41 
42   TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor);
43 };
44 
45 // Handles sentence -> numeric_features and numeric_prediction -> language
46 // conversions.
47 class LangIdBrainInterface {
48  public:
LangIdBrainInterface()49   LangIdBrainInterface() {}
50 
51   // Initializes resources and parameters.
Init(TaskContext * context)52   bool Init(TaskContext *context) {
53     if (!feature_extractor_.Init(context)) {
54       return false;
55     }
56     feature_extractor_.RequestWorkspaces(&workspace_registry_);
57     return true;
58   }
59 
60   // Extract features from sentence.  On return, FeatureVector features[i]
61   // contains the features for the embedding space #i.
GetFeatures(LightSentence * sentence,std::vector<FeatureVector> * features)62   void GetFeatures(LightSentence *sentence,
63                    std::vector<FeatureVector> *features) const {
64     WorkspaceSet workspace;
65     workspace.Reset(workspace_registry_);
66     feature_extractor_.Preprocess(&workspace, sentence);
67     return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
68   }
69 
NumEmbeddings()70   int NumEmbeddings() const {
71     return feature_extractor_.NumEmbeddings();
72   }
73 
74  private:
75   // Typed feature extractor for embeddings.
76   LangIdEmbeddingFeatureExtractor feature_extractor_;
77 
78   // The registry of shared workspaces in the feature extractor.
79   WorkspaceRegistry workspace_registry_;
80 
81   TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface);
82 };
83 
84 }  // namespace lang_id
85 }  // namespace nlp_core
86 }  // namespace libtextclassifier
87 
88 #endif  // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
89