/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/lang-id.h"

#include <stdio.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "lang_id/common/embedding-feature-interface.h"
#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/embedding-network.h"
#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/numbers.h"
#include "lang_id/common/lite_strings/str-split.h"
#include "lang_id/common/lite_strings/stringpiece.h"
#include "lang_id/common/math/algorithm.h"
#include "lang_id/common/math/softmax.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/features/light-sentence-features.h"
#include "lang_id/light-sentence.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

namespace {
// Default value for the confidence threshold.  If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode.  Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
// use that value instead.  Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability" and "reliability" equivalently.
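// For example, with the default threshold of 0.50, a top prediction with
// probability 0.42 is reported as LangId::kUnknownLanguageCode.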
static const float kDefaultConfidenceThreshold = 0.50f;
}  // namespace

// Class that performs all work behind LangId.
class LangIdImpl {
 public:
  explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
      : model_provider_(std::move(model_provider)),
        lang_id_brain_interface_("language_identifier") {
    // Note: in the code below, we set valid_ to true only if all
    // initialization steps complete successfully.  Otherwise, we return early,
    // leaving valid_ at its default value, false.
    if (!model_provider_ || !model_provider_->is_valid()) {
      SAFTM_LOG(ERROR) << "Invalid model provider";
      return;
    }

    auto *nn_params = model_provider_->GetNnParams();
    if (!nn_params) {
      SAFTM_LOG(ERROR) << "No NN params";
      return;
    }
    network_.reset(new EmbeddingNetwork(nn_params));

    languages_ = model_provider_->GetLanguages();
    if (languages_.empty()) {
      SAFTM_LOG(ERROR) << "No known languages";
      return;
    }

    TaskContext context = *model_provider_->GetTaskContext();
    if (!Setup(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
      return;
    }
    if (!Init(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Init() LangId";
      return;
    }
    valid_ = true;
  }

  string FindLanguage(StringPiece text) const {
    // NOTE: it would be wasteful to implement this method in terms of
    // FindLanguages().  We just need the most likely language and its
    // probability; no need to compute (and allocate) a vector of pairs for all
    // languages, nor to compute probabilities for all non-top languages.
    if (!is_valid()) {
      return LangId::kUnknownLanguageCode;
    }

    std::vector<float> scores;
    ComputeScores(text, &scores);

    int prediction_id = GetArgMax(scores);
    const string language = GetLanguageForSoftmaxLabel(prediction_id);
    float probability = ComputeSoftmaxProbability(scores, prediction_id);
    SAFTM_DLOG(INFO) << "Predicted " << language
                     << " with prob: " << probability << " for \"" << text
                     << "\"";

    // Find confidence threshold for language.
    float threshold = default_threshold_;
    auto it = per_lang_thresholds_.find(language);
    if (it != per_lang_thresholds_.end()) {
      threshold = it->second;
    }
    if (probability < threshold) {
      SAFTM_DLOG(INFO) << "  below threshold => "
                       << LangId::kUnknownLanguageCode;
      return LangId::kUnknownLanguageCode;
    }
    return language;
  }

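  // Computes softmax probabilities for all known languages and fills
  // result->predictions with (language code, probability) pairs, sorted by
  // probability in descending order; ties are broken by language code in
  // ascending order.  Illustrative output (hypothetical values):
  //   {{"en", 0.71}, {"fr", 0.20}, {"ru", 0.09}}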
  void FindLanguages(StringPiece text, LangIdResult *result) const {
    if (result == nullptr) return;

    result->predictions.clear();
    if (!is_valid()) {
      result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
      return;
    }

    std::vector<float> scores;
    ComputeScores(text, &scores);

    // Compute and sort softmax in descending order by probability and convert
    // IDs to language code strings.  When probabilities are equal, we sort by
    // language code string in ascending order.
    std::vector<float> softmax = ComputeSoftmax(scores);

    for (int i = 0; i < softmax.size(); ++i) {
      result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
                                       softmax[i]);
    }

    // Sort the resulting language predictions by probability in descending
    // order.
    std::sort(result->predictions.begin(), result->predictions.end(),
              [](const std::pair<string, float> &a,
                 const std::pair<string, float> &b) {
                if (a.second == b.second) {
                  return a.first.compare(b.first) < 0;
                } else {
                  return a.second > b.second;
                }
              });
  }

  bool is_valid() const { return valid_; }

  int GetModelVersion() const { return model_version_; }

  // Returns a property stored in the model file.
  template <typename T, typename R>
  R GetProperty(const string &property, T default_value) const {
    return model_provider_->GetTaskContext()->Get(property, default_value);
  }

 private:
  bool Setup(TaskContext *context) {
    tokenizer_.Setup(context);
    if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
    default_threshold_ =
        context->Get("reliability_thresh", kDefaultConfidenceThreshold);

    // Parse task parameter "per_lang_reliability_thresholds", fill
    // per_lang_thresholds_.
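    // The parameter value is a comma-separated list of <language>=<threshold>
    // pairs, e.g. (hypothetical values) "en=0.55,ja=0.70"; malformed pairs are
    // logged and skipped.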
    const string thresholds_str =
        context->Get("per_lang_reliability_thresholds", "");
    std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
    for (const auto &token : tokens) {
      if (token.empty()) continue;
      std::vector<StringPiece> parts = LiteStrSplit(token, '=');
      float threshold = 0.0f;
      if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
        per_lang_thresholds_[string(parts[0])] = threshold;
      } else {
        SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
      }
    }
    model_version_ = context->Get("model_version", model_version_);
    return true;
  }

  bool Init(TaskContext *context) {
    return lang_id_brain_interface_.InitForProcessing(context);
  }

  // Extracts features for |text|, runs them through the feed-forward neural
  // network, and computes the output scores (activations from the last layer).
  // These scores can be used to compute the softmax probabilities for our
  // labels (in this case, the languages).
  void ComputeScores(StringPiece text, std::vector<float> *scores) const {
    // Create a Sentence storing the input text.
    LightSentence sentence;
    tokenizer_.Tokenize(text, &sentence);

    std::vector<FeatureVector> features =
        lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);

    // Run feed-forward neural network to compute scores.
    network_->ComputeFinalScores(features, scores);
  }

  // Returns language code for a softmax label.  See comments for languages_
  // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
  string GetLanguageForSoftmaxLabel(int label) const {
    if ((label >= 0) && (label < languages_.size())) {
      return languages_[label];
    } else {
      SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
                       << languages_.size() << ")";
      return LangId::kUnknownLanguageCode;
    }
  }

  std::unique_ptr<ModelProvider> model_provider_;

  TokenizerForLangId tokenizer_;

  EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
      lang_id_brain_interface_;

  // Neural network to use for scoring.
  std::unique_ptr<EmbeddingNetwork> network_;

  // True if this object is ready to perform language predictions.
  bool valid_ = false;

  // Only predictions with a probability (confidence) above this threshold are
  // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
  float default_threshold_ = kDefaultConfidenceThreshold;

  std::unordered_map<string, float> per_lang_thresholds_;

  // Recognized languages: softmax label i means languages_[i] (something like
  // "en", "fr", "ru", etc).
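  // For example (hypothetical contents), if languages_ == {"en", "fr", "ru"},
  // then softmax label 1 corresponds to "fr".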
  std::vector<string> languages_;

  // Version of the model used by this LangIdImpl object.  Zero means that the
  // model version could not be determined.
  int model_version_ = 0;
};

const char LangId::kUnknownLanguageCode[] = "und";

LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
    : pimpl_(new LangIdImpl(std::move(model_provider))) {}

LangId::~LangId() = default;

string LangId::FindLanguage(const char *data, size_t num_bytes) const {
  StringPiece text(data, num_bytes);
  return pimpl_->FindLanguage(text);
}

void LangId::FindLanguages(const char *data, size_t num_bytes,
                           LangIdResult *result) const {
  SAFTM_DCHECK(result) << "LangIdResult must not be null.";
  StringPiece text(data, num_bytes);
  pimpl_->FindLanguages(text, result);
}

bool LangId::is_valid() const { return pimpl_->is_valid(); }

int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }

float LangId::GetFloatProperty(const string &property,
                               float default_value) const {
  return pimpl_->GetProperty<float, float>(property, default_value);
}
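
// Illustrative usage of the public LangId API (a sketch; |provider| stands in
// for an application-specific ModelProvider implementation that loads the
// model, which is not defined in this file):
//
//   std::unique_ptr<ModelProvider> provider = ...;  // application-specific
//   LangId lang_id(std::move(provider));
//   if (lang_id.is_valid()) {
//     const string code = lang_id.FindLanguage(text.data(), text.size());
//     // |code| is a language code such as "en", or
//     // LangId::kUnknownLanguageCode ("und") if the top prediction falls
//     // below the confidence threshold.
//   }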

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3