1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/translate/translate.h"
18 
19 #include <memory>
20 
21 #include "annotator/model_generated.h"
22 #include "utils/test-data-test-utils.h"
23 #include "lang_id/fb_model/lang-id-from-fb.h"
24 #include "lang_id/lang-id.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 
28 namespace libtextclassifier3 {
29 namespace {
30 
31 using testing::AllOf;
32 using testing::Field;
33 
TestingTranslateAnnotatorOptions()34 const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
35   static const flatbuffers::DetachedBuffer* options_data = []() {
36     TranslateAnnotatorOptionsT options;
37     options.enabled = true;
38     options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
39     options.backoff_options.reset(
40         new TranslateAnnotatorOptions_::BackoffOptionsT());
41 
42     flatbuffers::FlatBufferBuilder builder;
43     builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
44     return new flatbuffers::DetachedBuffer(builder.Release());
45   }();
46 
47   return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
48 }
49 
50 class TestingTranslateAnnotator : public TranslateAnnotator {
51  public:
52   // Make these protected members public for tests.
53   using TranslateAnnotator::BackoffDetectLanguages;
54   using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
55   using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
56   using TranslateAnnotator::TranslateAnnotator;
57 };
58 
GetModelPath()59 std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
60 
61 class TranslateAnnotatorTest : public ::testing::Test {
62  protected:
TranslateAnnotatorTest()63   TranslateAnnotatorTest()
64       : INIT_UNILIB_FOR_TESTING(unilib_),
65         langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
66             GetModelPath() + "lang_id.smfb")),
67         translate_annotator_(TestingTranslateAnnotatorOptions(),
68                              langid_model_.get(), &unilib_) {}
69 
70   UniLib unilib_;
71   std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
72   TestingTranslateAnnotator translate_annotator_;
73 };
74 
// With "en" as the only user-familiar language, a selection inside a Czech
// sentence must be classified as "translate", and the serialized entity data
// must carry exactly one language prediction: Czech, with a confidence in
// (0, 1].
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
  ClassificationResult classification;
  // Fatal on failure: everything below reads the entity data that a successful
  // ClassifyText() call is expected to have written into `classification`.
  ASSERT_TRUE(translate_annotator_.ClassifyText(
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
      "en", &classification));

  EXPECT_THAT(classification,
              Field(&ClassificationResult::collection, "translate"));

  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto* predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);
  // Fatal as well, so that Get(0) is never reached on a vector of another
  // size. The unsigned literal matches the unsigned size type.
  ASSERT_EQ(predictions->size(), 1u);

  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
}
92 
// Classifying "学校" for an English-speaking user must produce "translate"
// entity data with two language predictions, "zh" first and "ja" second, with
// the first confidence in (0, 1] and not lower than the second.
TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
  ClassificationResult classification;
  // Fatal on failure: the entity data parsed below only exists if
  // ClassifyText() succeeded.
  ASSERT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
                                                {0, 2}, "en", &classification));

  EXPECT_THAT(classification,
              Field(&ClassificationResult::collection, "translate"));

  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto* predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);
  // Fatal as well, so that Get(0) and Get(1) are never reached on a vector of
  // another size. The unsigned literal matches the unsigned size type.
  ASSERT_EQ(predictions->size(), 2u);

  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
  EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
  // EXPECT_GE prints both scores when the ordering is violated.
  EXPECT_GE(predictions->Get(0)->confidence_score(),
            predictions->Get(1)->confidence_score());
}
112 
// English text selected by a user whose only familiar language is English
// must not be classified for translation.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english_text =
      UTF8ToUnicodeText("This is utterly unutterable.");

  ClassificationResult classification;
  const bool classified = translate_annotator_.ClassifyText(
      english_text, {8, 15}, "en", &classification);

  EXPECT_FALSE(classified);
}
120 
// Czech text selected by a user familiar with German, English and Japanese
// (but not Czech) must be classified as "translate".
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
  const UnicodeText czech_text =
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček.");

  ClassificationResult classification;
  const bool classified = translate_annotator_.ClassifyText(
      czech_text, {8, 15}, "de,en,ja", &classification);

  EXPECT_TRUE(classified);
  EXPECT_THAT(classification,
              Field(&ClassificationResult::collection, "translate"));
}
131 
// English text selected by a user whose familiar languages include English
// (alongside Czech, German and Japanese) must not be classified for
// translation.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english_text =
      UTF8ToUnicodeText("This is utterly unutterable.");

  ClassificationResult classification;
  const bool classified = translate_annotator_.ClassifyText(
      english_text, {8, 15}, "cs,en,de,ja", &classification);

  EXPECT_FALSE(classified);
}
139 
// Checks the iterator returned by FindIndexOfNextWhitespaceOrPunctuation() for
// several (start index, direction) pairs over a Czech phrase containing both
// spaces and a comma.
TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");

  // Shorthand for the call under test, always applied to `text`.
  const auto find_boundary = [this, &text](int start_index, int direction) {
    return translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(
        text, start_index, direction);
  };

  // Backwards from the very first codepoint: stays at the beginning.
  EXPECT_EQ(find_boundary(0, -1), text.begin());
  // Forwards from inside the last word: runs to the end of the text.
  EXPECT_EQ(find_boundary(35, 1), text.end());
  // Backwards from inside the second word: the codepoint at offset 6.
  EXPECT_EQ(find_boundary(10, -1), std::next(text.begin(), 6));
  // Forwards from inside the second word: the codepoint at offset 13.
  EXPECT_EQ(find_boundary(10, 1), std::next(text.begin(), 13));
}
157 
// Checks the substring returned by TokenAlignedSubstringAroundSpan() for a
// fixed span inside the last word of a Czech phrase at increasing
// minimum_length values, then for a span inside a text with no whitespace.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");

  // All expectations on `text` use the same span, {35, 37}; only the requested
  // minimum length varies.
  const auto substring_with_minimum = [this, &text](int minimum_length) {
    return translate_annotator_.TokenAlignedSubstringAroundSpan(
        text, {35, 37}, minimum_length);
  };

  EXPECT_EQ(substring_with_minimum(100), text);
  EXPECT_EQ(substring_with_minimum(0), UTF8ToUnicodeText("ač"));
  EXPECT_EQ(substring_with_minimum(3), UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(substring_with_minimum(10), UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(substring_with_minimum(11),
            UTF8ToUnicodeText("stříbrných stříkaček"));

  // A text without any whitespace is returned whole.
  const UnicodeText text_no_whitespace =
      UTF8ToUnicodeText("reallyreallylongstring");
  const UnicodeText no_whitespace_result =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          text_no_whitespace, {10, 11}, /*minimum_length=*/2);
  EXPECT_EQ(no_whitespace_result, UTF8ToUnicodeText("reallyreallylongstring"));
}
184 
// For a text made only of spaces, TokenAlignedSubstringAroundSpan() must
// return exactly the selected span, even when minimum_length asks for more.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
  const UnicodeText text = UTF8ToUnicodeText("            ");

  // Two-codepoint selection: comes back unchanged.
  const UnicodeText two_spaces =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          text, {5, 7}, /*minimum_length=*/3);
  EXPECT_EQ(two_spaces, UTF8ToUnicodeText("  "));

  // Empty selection: comes back empty.
  const UnicodeText empty_selection =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          text, {5, 5}, /*minimum_length=*/1);
  EXPECT_EQ(empty_selection, UTF8ToUnicodeText(""));
}
196 
// When the selection lies in the run of spaces between two letters and
// minimum_length exceeds the text length, the whole text must be returned.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
  const UnicodeText text = UTF8ToUnicodeText("a        a");

  const UnicodeText aligned =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          text, {5, 7}, /*minimum_length=*/11);

  EXPECT_EQ(aligned, UTF8ToUnicodeText("a        a"));
}
206 
207 }  // namespace
208 }  // namespace libtextclassifier3
209