1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/lang-id.h"
18
19 #include <stdio.h>
20
21 #include <algorithm>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26
27 #include "lang_id/common/embedding-feature-interface.h"
28 #include "lang_id/common/embedding-network-params.h"
29 #include "lang_id/common/embedding-network.h"
30 #include "lang_id/common/fel/feature-extractor.h"
31 #include "lang_id/common/lite_base/logging.h"
32 #include "lang_id/common/lite_strings/numbers.h"
33 #include "lang_id/common/lite_strings/str-split.h"
34 #include "lang_id/common/lite_strings/stringpiece.h"
35 #include "lang_id/common/math/algorithm.h"
36 #include "lang_id/common/math/softmax.h"
37 #include "lang_id/custom-tokenizer.h"
38 #include "lang_id/features/light-sentence-features.h"
39 #include "lang_id/light-sentence.h"
40
41 namespace libtextclassifier3 {
42 namespace mobile {
43 namespace lang_id {
44
namespace {

// Fallback confidence threshold.  FindLanguage() reports
// LangId::kUnknownLanguageCode whenever the probability of its best guess is
// lower than the threshold in effect.  This constant is only the default: a
// "reliability_thresh" task parameter coming from the model takes precedence.
//
// Terminology: for legacy reasons, "confidence", "probability" and
// "reliability" are used interchangeably in this file.
constexpr float kDefaultConfidenceThreshold = 0.50f;

}  // namespace
54
55 // Class that performs all work behind LangId.
56 class LangIdImpl {
57 public:
LangIdImpl(std::unique_ptr<ModelProvider> model_provider)58 explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
59 : model_provider_(std::move(model_provider)),
60 lang_id_brain_interface_("language_identifier") {
61 // Note: in the code below, we set valid_ to true only if all initialization
62 // steps completed successfully. Otherwise, we return early, leaving valid_
63 // to its default value false.
64 if (!model_provider_ || !model_provider_->is_valid()) {
65 SAFTM_LOG(ERROR) << "Invalid model provider";
66 return;
67 }
68
69 auto *nn_params = model_provider_->GetNnParams();
70 if (!nn_params) {
71 SAFTM_LOG(ERROR) << "No NN params";
72 return;
73 }
74 network_.reset(new EmbeddingNetwork(nn_params));
75
76 languages_ = model_provider_->GetLanguages();
77 if (languages_.empty()) {
78 SAFTM_LOG(ERROR) << "No known languages";
79 return;
80 }
81
82 TaskContext context = *model_provider_->GetTaskContext();
83 if (!Setup(&context)) {
84 SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
85 return;
86 }
87 if (!Init(&context)) {
88 SAFTM_LOG(ERROR) << "Unable to Init() LangId";
89 return;
90 }
91 valid_ = true;
92 }
93
FindLanguage(StringPiece text) const94 string FindLanguage(StringPiece text) const {
95 // NOTE: it would be wasteful to implement this method in terms of
96 // FindLanguages(). We just need the most likely language and its
97 // probability; no need to compute (and allocate) a vector of pairs for all
98 // languages, nor to compute probabilities for all non-top languages.
99 if (!is_valid()) {
100 return LangId::kUnknownLanguageCode;
101 }
102
103 std::vector<float> scores;
104 ComputeScores(text, &scores);
105
106 int prediction_id = GetArgMax(scores);
107 const string language = GetLanguageForSoftmaxLabel(prediction_id);
108 float probability = ComputeSoftmaxProbability(scores, prediction_id);
109 SAFTM_DLOG(INFO) << "Predicted " << language
110 << " with prob: " << probability << " for \"" << text
111 << "\"";
112
113 // Find confidence threshold for language.
114 float threshold = default_threshold_;
115 auto it = per_lang_thresholds_.find(language);
116 if (it != per_lang_thresholds_.end()) {
117 threshold = it->second;
118 }
119 if (probability < threshold) {
120 SAFTM_DLOG(INFO) << " below threshold => "
121 << LangId::kUnknownLanguageCode;
122 return LangId::kUnknownLanguageCode;
123 }
124 return language;
125 }
126
FindLanguages(StringPiece text,LangIdResult * result) const127 void FindLanguages(StringPiece text, LangIdResult *result) const {
128 if (result == nullptr) return;
129
130 result->predictions.clear();
131 if (!is_valid()) {
132 result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
133 return;
134 }
135
136 std::vector<float> scores;
137 ComputeScores(text, &scores);
138
139 // Compute and sort softmax in descending order by probability and convert
140 // IDs to language code strings. When probabilities are equal, we sort by
141 // language code string in ascending order.
142 std::vector<float> softmax = ComputeSoftmax(scores);
143
144 for (int i = 0; i < softmax.size(); ++i) {
145 result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
146 softmax[i]);
147 }
148
149 // Sort the resulting language predictions by probability in descending
150 // order.
151 std::sort(result->predictions.begin(), result->predictions.end(),
152 [](const std::pair<string, float> &a,
153 const std::pair<string, float> &b) {
154 if (a.second == b.second) {
155 return a.first.compare(b.first) < 0;
156 } else {
157 return a.second > b.second;
158 }
159 });
160 }
161
is_valid() const162 bool is_valid() const { return valid_; }
163
GetModelVersion() const164 int GetModelVersion() const { return model_version_; }
165
166 // Returns a property stored in the model file.
167 template <typename T, typename R>
GetProperty(const string & property,T default_value) const168 R GetProperty(const string &property, T default_value) const {
169 return model_provider_->GetTaskContext()->Get(property, default_value);
170 }
171
172 private:
Setup(TaskContext * context)173 bool Setup(TaskContext *context) {
174 tokenizer_.Setup(context);
175 if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
176 default_threshold_ =
177 context->Get("reliability_thresh", kDefaultConfidenceThreshold);
178
179 // Parse task parameter "per_lang_reliability_thresholds", fill
180 // per_lang_thresholds_.
181 const string thresholds_str =
182 context->Get("per_lang_reliability_thresholds", "");
183 std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
184 for (const auto &token : tokens) {
185 if (token.empty()) continue;
186 std::vector<StringPiece> parts = LiteStrSplit(token, '=');
187 float threshold = 0.0f;
188 if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
189 per_lang_thresholds_[string(parts[0])] = threshold;
190 } else {
191 SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
192 }
193 }
194 model_version_ = context->Get("model_version", model_version_);
195 return true;
196 }
197
Init(TaskContext * context)198 bool Init(TaskContext *context) {
199 return lang_id_brain_interface_.InitForProcessing(context);
200 }
201
202 // Extracts features for |text|, runs them through the feed-forward neural
203 // network, and computes the output scores (activations from the last layer).
204 // These scores can be used to compute the softmax probabilities for our
205 // labels (in this case, the languages).
ComputeScores(StringPiece text,std::vector<float> * scores) const206 void ComputeScores(StringPiece text, std::vector<float> *scores) const {
207 // Create a Sentence storing the input text.
208 LightSentence sentence;
209 tokenizer_.Tokenize(text, &sentence);
210
211 std::vector<FeatureVector> features =
212 lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
213
214 // Run feed-forward neural network to compute scores.
215 network_->ComputeFinalScores(features, scores);
216 }
217
218 // Returns language code for a softmax label. See comments for languages_
219 // field. If label is out of range, returns LangId::kUnknownLanguageCode.
GetLanguageForSoftmaxLabel(int label) const220 string GetLanguageForSoftmaxLabel(int label) const {
221 if ((label >= 0) && (label < languages_.size())) {
222 return languages_[label];
223 } else {
224 SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
225 << languages_.size() << ")";
226 return LangId::kUnknownLanguageCode;
227 }
228 }
229
230 std::unique_ptr<ModelProvider> model_provider_;
231
232 TokenizerForLangId tokenizer_;
233
234 EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
235 lang_id_brain_interface_;
236
237 // Neural network to use for scoring.
238 std::unique_ptr<EmbeddingNetwork> network_;
239
240 // True if this object is ready to perform language predictions.
241 bool valid_ = false;
242
243 // Only predictions with a probability (confidence) above this threshold are
244 // reported. Otherwise, we report LangId::kUnknownLanguageCode.
245 float default_threshold_ = kDefaultConfidenceThreshold;
246
247 std::unordered_map<string, float> per_lang_thresholds_;
248
249 // Recognized languages: softmax label i means languages_[i] (something like
250 // "en", "fr", "ru", etc).
251 std::vector<string> languages_;
252
253 // Version of the model used by this LangIdImpl object. Zero means that the
254 // model version could not be determined.
255 int model_version_ = 0;
256 };
257
258 const char LangId::kUnknownLanguageCode[] = "und";
259
LangId(std::unique_ptr<ModelProvider> model_provider)260 LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
261 : pimpl_(new LangIdImpl(std::move(model_provider))) {}
262
263 LangId::~LangId() = default;
264
FindLanguage(const char * data,size_t num_bytes) const265 string LangId::FindLanguage(const char *data, size_t num_bytes) const {
266 StringPiece text(data, num_bytes);
267 return pimpl_->FindLanguage(text);
268 }
269
FindLanguages(const char * data,size_t num_bytes,LangIdResult * result) const270 void LangId::FindLanguages(const char *data, size_t num_bytes,
271 LangIdResult *result) const {
272 SAFTM_DCHECK(result) << "LangIdResult must not be null.";
273 StringPiece text(data, num_bytes);
274 pimpl_->FindLanguages(text, result);
275 }
276
is_valid() const277 bool LangId::is_valid() const { return pimpl_->is_valid(); }
278
GetModelVersion() const279 int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
280
GetFloatProperty(const string & property,float default_value) const281 float LangId::GetFloatProperty(const string &property,
282 float default_value) const {
283 return pimpl_->GetProperty<float, float>(property, default_value);
284 }
285
286 } // namespace lang_id
287 } // namespace mobile
288 } // namespace libtextclassifier3
289