1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/translate/translate.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "annotator/entity-data_generated.h"
#include "annotator/types.h"
#include "lang_id/lang-id-wrapper.h"
#include "lang_id/lang-id.h"
#include "utils/base/logging.h"
#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
29
30 namespace libtextclassifier3 {
31
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const32 bool TranslateAnnotator::ClassifyText(
33 const UnicodeText& context, CodepointSpan selection_indices,
34 const std::string& user_familiar_language_tags,
35 ClassificationResult* classification_result) const {
36 std::vector<TranslateAnnotator::LanguageConfidence> confidences;
37 if (options_->algorithm() ==
38 TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
39 if (options_->backoff_options() == nullptr) {
40 TC3_LOG(WARNING) << "No backoff options specified. Returning.";
41 return false;
42 }
43 confidences = BackoffDetectLanguages(context, selection_indices);
44 }
45
46 if (confidences.empty()) {
47 return false;
48 }
49
50 std::vector<Locale> user_familiar_languages;
51 if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
52 TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
53 return false;
54 }
55 if (user_familiar_languages.empty()) {
56 TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
57 "translate action.";
58 return false;
59 }
60 bool user_can_understand_language_of_text = false;
61 for (const Locale& locale : user_familiar_languages) {
62 if (locale.Language() == confidences[0].language) {
63 user_can_understand_language_of_text = true;
64 break;
65 }
66 }
67
68 if (!user_can_understand_language_of_text) {
69 classification_result->collection = Collections::Translate();
70 classification_result->score = options_->score();
71 classification_result->priority_score = options_->priority_score();
72 classification_result->serialized_entity_data =
73 CreateSerializedEntityData(confidences);
74 return true;
75 }
76
77 return false;
78 }
79
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const80 std::string TranslateAnnotator::CreateSerializedEntityData(
81 const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
82 const {
83 EntityDataT entity_data;
84 entity_data.translate.reset(new EntityData_::TranslateT());
85
86 for (const LanguageConfidence& confidence : confidences) {
87 EntityData_::Translate_::LanguagePredictionResultT*
88 language_prediction_result =
89 new EntityData_::Translate_::LanguagePredictionResultT();
90 language_prediction_result->language_tag = confidence.language;
91 language_prediction_result->confidence_score = confidence.confidence;
92 entity_data.translate->language_prediction_results.emplace_back(
93 language_prediction_result);
94 }
95 flatbuffers::FlatBufferBuilder builder;
96 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
97 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
98 builder.GetSize());
99 }
100
101 std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const102 TranslateAnnotator::BackoffDetectLanguages(
103 const UnicodeText& context, CodepointSpan selection_indices) const {
104 const float penalize_ratio = options_->backoff_options()->penalize_ratio();
105 const int min_text_size = options_->backoff_options()->min_text_size();
106 if (selection_indices.second - selection_indices.first < min_text_size &&
107 penalize_ratio <= 0) {
108 return {};
109 }
110
111 const UnicodeText entity =
112 UnicodeText::Substring(context, selection_indices.first,
113 selection_indices.second, /*do_copy=*/false);
114 const std::vector<std::pair<std::string, float>> lang_id_result =
115 langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
116
117 const float more_text_score_ratio =
118 1.0f - options_->backoff_options()->subject_text_score_ratio();
119 std::vector<std::pair<std::string, float>> more_lang_id_results;
120 if (more_text_score_ratio >= 0) {
121 const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
122 context, selection_indices, min_text_size);
123 more_lang_id_results =
124 langid::GetPredictions(langid_model_, entity_with_context.data(),
125 entity_with_context.size_bytes());
126 }
127
128 const float subject_text_score_ratio =
129 options_->backoff_options()->subject_text_score_ratio();
130
131 std::map<std::string, float> result_map;
132 for (const auto& [language, score] : lang_id_result) {
133 result_map[language] = subject_text_score_ratio * score;
134 }
135 for (const auto& [language, score] : more_lang_id_results) {
136 result_map[language] += more_text_score_ratio * score * penalize_ratio;
137 }
138
139 std::vector<TranslateAnnotator::LanguageConfidence> result;
140 result.reserve(result_map.size());
141 for (const auto& [key, value] : result_map) {
142 result.push_back({key, value});
143 }
144
145 std::sort(result.begin(), result.end(),
146 [](TranslateAnnotator::LanguageConfidence& a,
147 TranslateAnnotator::LanguageConfidence& b) {
148 return a.confidence > b.confidence;
149 });
150 return result;
151 }
152
153 UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const154 TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
155 const UnicodeText& text, int start_index, int direction) const {
156 TC3_CHECK(direction == 1 || direction == -1);
157 auto it = text.begin();
158 std::advance(it, start_index);
159 while (it > text.begin() && it < text.end()) {
160 if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
161 break;
162 }
163 std::advance(it, direction);
164 }
165 return it;
166 }
167
// Returns a non-copying view of `text` that covers the span `indices`,
// widened to roughly `minimum_length` codepoints of surrounding context. The
// window edges are pushed outward to the nearest whitespace/punctuation.
//
// - If the whole text is shorter than `minimum_length`, the whole text is
//   returned.
// - If the span itself has at least `minimum_length` codepoints, exactly the
//   span is returned.
//
// NOTE(review): assumes `indices` lies within [0, text.size_codepoints()];
// confirm at call sites.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int total_codepoints = text.size_codepoints();
  if (total_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Place a `minimum_length` window roughly centred on the span:
  // - pad the left side by half the shortfall,
  // - pull the window back if it would overrun the end of the text,
  // - never start before codepoint 0.
  const int left_padding = (minimum_length - span_length) / 2;
  const int latest_window_begin = total_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - left_padding, latest_window_begin));
  const int window_end =
      std::min(total_codepoints, window_begin + minimum_length);

  // Grow the window outward to the nearest whitespace/punctuation on each side.
  auto first = FindIndexOfNextWhitespaceOrPunctuation(text, window_begin, -1);
  const auto last = FindIndexOfNextWhitespaceOrPunctuation(text, window_end, 1);

  // `first` now sits on a whitespace/punctuation codepoint, or on
  // text.begin(). Step past it only when it is whitespace; a leading
  // punctuation mark is kept in the result.
  // NOTE(review): `last` presumably acts as an exclusive end, which drops a
  // trailing punctuation mark while a leading one is kept. Confirm this
  // asymmetry is intended.
  if (first != last && unilib_->IsWhitespace(*first)) {
    std::advance(first, 1);
  }

  return UnicodeText::Substring(first, last, /*do_copy=*/false);
}
200
201 } // namespace libtextclassifier3
202