1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "annotator/translate/translate.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "annotator/entity-data_generated.h"
#include "annotator/types.h"
#include "lang_id/lang-id-wrapper.h"
#include "lang_id/lang-id.h"
#include "utils/base/logging.h"
#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
29 
30 namespace libtextclassifier3 {
31 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const32 bool TranslateAnnotator::ClassifyText(
33     const UnicodeText& context, CodepointSpan selection_indices,
34     const std::string& user_familiar_language_tags,
35     ClassificationResult* classification_result) const {
36   std::vector<TranslateAnnotator::LanguageConfidence> confidences;
37   if (options_->algorithm() ==
38       TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
39     if (options_->backoff_options() == nullptr) {
40       TC3_LOG(WARNING) << "No backoff options specified. Returning.";
41       return false;
42     }
43     confidences = BackoffDetectLanguages(context, selection_indices);
44   }
45 
46   if (confidences.empty()) {
47     return false;
48   }
49 
50   std::vector<Locale> user_familiar_languages;
51   if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
52     TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
53     return false;
54   }
55   if (user_familiar_languages.empty()) {
56     TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
57                       "translate action.";
58     return false;
59   }
60   bool user_can_understand_language_of_text = false;
61   for (const Locale& locale : user_familiar_languages) {
62     if (locale.Language() == confidences[0].language) {
63       user_can_understand_language_of_text = true;
64       break;
65     }
66   }
67 
68   if (!user_can_understand_language_of_text) {
69     classification_result->collection = Collections::Translate();
70     classification_result->score = options_->score();
71     classification_result->priority_score = options_->priority_score();
72     classification_result->serialized_entity_data =
73         CreateSerializedEntityData(confidences);
74     return true;
75   }
76 
77   return false;
78 }
79 
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const80 std::string TranslateAnnotator::CreateSerializedEntityData(
81     const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
82     const {
83   EntityDataT entity_data;
84   entity_data.translate.reset(new EntityData_::TranslateT());
85 
86   for (const LanguageConfidence& confidence : confidences) {
87     EntityData_::Translate_::LanguagePredictionResultT*
88         language_prediction_result =
89             new EntityData_::Translate_::LanguagePredictionResultT();
90     language_prediction_result->language_tag = confidence.language;
91     language_prediction_result->confidence_score = confidence.confidence;
92     entity_data.translate->language_prediction_results.emplace_back(
93         language_prediction_result);
94   }
95   flatbuffers::FlatBufferBuilder builder;
96   FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
97   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
98                      builder.GetSize());
99 }
100 
101 std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const102 TranslateAnnotator::BackoffDetectLanguages(
103     const UnicodeText& context, CodepointSpan selection_indices) const {
104   const float penalize_ratio = options_->backoff_options()->penalize_ratio();
105   const int min_text_size = options_->backoff_options()->min_text_size();
106   if (selection_indices.second - selection_indices.first < min_text_size &&
107       penalize_ratio <= 0) {
108     return {};
109   }
110 
111   const UnicodeText entity =
112       UnicodeText::Substring(context, selection_indices.first,
113                              selection_indices.second, /*do_copy=*/false);
114   const std::vector<std::pair<std::string, float>> lang_id_result =
115       langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
116 
117   const float more_text_score_ratio =
118       1.0f - options_->backoff_options()->subject_text_score_ratio();
119   std::vector<std::pair<std::string, float>> more_lang_id_results;
120   if (more_text_score_ratio >= 0) {
121     const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
122         context, selection_indices, min_text_size);
123     more_lang_id_results =
124         langid::GetPredictions(langid_model_, entity_with_context.data(),
125                                entity_with_context.size_bytes());
126   }
127 
128   const float subject_text_score_ratio =
129       options_->backoff_options()->subject_text_score_ratio();
130 
131   std::map<std::string, float> result_map;
132   for (const auto& [language, score] : lang_id_result) {
133     result_map[language] = subject_text_score_ratio * score;
134   }
135   for (const auto& [language, score] : more_lang_id_results) {
136     result_map[language] += more_text_score_ratio * score * penalize_ratio;
137   }
138 
139   std::vector<TranslateAnnotator::LanguageConfidence> result;
140   result.reserve(result_map.size());
141   for (const auto& [key, value] : result_map) {
142     result.push_back({key, value});
143   }
144 
145   std::sort(result.begin(), result.end(),
146             [](TranslateAnnotator::LanguageConfidence& a,
147                TranslateAnnotator::LanguageConfidence& b) {
148               return a.confidence > b.confidence;
149             });
150   return result;
151 }
152 
153 UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const154 TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
155     const UnicodeText& text, int start_index, int direction) const {
156   TC3_CHECK(direction == 1 || direction == -1);
157   auto it = text.begin();
158   std::advance(it, start_index);
159   while (it > text.begin() && it < text.end()) {
160     if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
161       break;
162     }
163     std::advance(it, direction);
164   }
165   return it;
166 }
167 
// Returns a non-copying substring of `text` around the span `indices`.
//
// - If `text` has fewer than `minimum_length` codepoints, all of `text` is
//   returned.
// - If the span already has at least `minimum_length` codepoints, exactly the
//   span is returned.
// - Otherwise a window of `minimum_length` codepoints roughly centered on the
//   span (and shifted to fit inside `text`) is chosen. Its edges are then
//   pushed outward to the nearest whitespace/punctuation using
//   FindIndexOfNextWhitespaceOrPunctuation.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int num_codepoints = text.size_codepoints();
  if (num_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Pad the span on the left by half of the missing length. The window never
  // starts before 0, nor so late that fewer than `minimum_length` codepoints
  // remain after it.
  const int left_padding = (minimum_length - span_length) / 2;
  const int latest_window_begin = num_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - left_padding, latest_window_begin));
  const int window_end =
      std::min(num_codepoints, window_begin + minimum_length);

  auto substring_begin = FindIndexOfNextWhitespaceOrPunctuation(
      text, window_begin, /*direction=*/-1);
  const auto substring_end = FindIndexOfNextWhitespaceOrPunctuation(
      text, window_end, /*direction=*/1);

  // The backward scan stops on the delimiter itself (or at the start of the
  // text). Only a whitespace delimiter is stepped over; a punctuation
  // delimiter stays in the returned substring.
  if (substring_begin != substring_end &&
      unilib_->IsWhitespace(*substring_begin)) {
    std::advance(substring_begin, 1);
  }

  return UnicodeText::Substring(substring_begin, substring_end,
                                /*do_copy=*/false);
}
200 
201 }  // namespace libtextclassifier3
202