/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "annotator/translate/translate.h"
18
19 #include <memory>
20
21 #include "annotator/model_generated.h"
22 #include "utils/test-data-test-utils.h"
23 #include "lang_id/fb_model/lang-id-from-fb.h"
24 #include "lang_id/lang-id.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27
28 namespace libtextclassifier3 {
29 namespace {
30
31 using testing::AllOf;
32 using testing::Field;
33
TestingTranslateAnnotatorOptions()34 const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
35 static const flatbuffers::DetachedBuffer* options_data = []() {
36 TranslateAnnotatorOptionsT options;
37 options.enabled = true;
38 options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
39 options.backoff_options.reset(
40 new TranslateAnnotatorOptions_::BackoffOptionsT());
41
42 flatbuffers::FlatBufferBuilder builder;
43 builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
44 return new flatbuffers::DetachedBuffer(builder.Release());
45 }();
46
47 return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
48 }
49
50 class TestingTranslateAnnotator : public TranslateAnnotator {
51 public:
52 // Make these protected members public for tests.
53 using TranslateAnnotator::BackoffDetectLanguages;
54 using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
55 using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
56 using TranslateAnnotator::TranslateAnnotator;
57 };
58
GetModelPath()59 std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
60
61 class TranslateAnnotatorTest : public ::testing::Test {
62 protected:
TranslateAnnotatorTest()63 TranslateAnnotatorTest()
64 : INIT_UNILIB_FOR_TESTING(unilib_),
65 langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
66 GetModelPath() + "lang_id.smfb")),
67 translate_annotator_(TestingTranslateAnnotatorOptions(),
68 langid_model_.get(), &unilib_) {}
69
70 UniLib unilib_;
71 std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
72 TestingTranslateAnnotator translate_annotator_;
73 };
74
// A selection inside Czech text, for a user whose only familiar language is
// English, must be classified as "translate" and carry exactly one language
// prediction ("cs") with a confidence in (0, 1].
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
  ClassificationResult classification;
  // Fatal on failure: everything below reads the contents of `classification`.
  ASSERT_TRUE(translate_annotator_.ClassifyText(
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
      "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));

  // Guard every step of the flatbuffer traversal so that a missing field is
  // reported as a test failure rather than a null dereference.
  ASSERT_FALSE(classification.serialized_entity_data.empty());
  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);

  // Fatal so that Get(0) is never called on a vector of the wrong size.
  ASSERT_EQ(predictions->size(), 1);
  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
}
92
// For a short CJK selection the entity data must list two language
// predictions, "zh" first and "ja" second, ordered by non-increasing
// confidence, with the top confidence in (0, 1].
TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
  ClassificationResult classification;
  // Fatal on failure: everything below reads the contents of `classification`.
  ASSERT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
                                                {0, 2}, "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));

  // Guard every step of the flatbuffer traversal so that a missing field is
  // reported as a test failure rather than a null dereference.
  ASSERT_FALSE(classification.serialized_entity_data.empty());
  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);

  // Fatal so that Get(0) / Get(1) are never called out of bounds.
  ASSERT_EQ(predictions->size(), 2);
  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
  EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
  // EXPECT_GE (rather than EXPECT_TRUE on a comparison) prints both scores
  // when the ordering is violated.
  EXPECT_GE(predictions->Get(0)->confidence_score(),
            predictions->Get(1)->confidence_score());
}
112
// English text must not be classified for translation when the user's only
// familiar language is English.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english_text =
      UTF8ToUnicodeText("This is utterly unutterable.");

  ClassificationResult result;
  const bool classified =
      translate_annotator_.ClassifyText(english_text, {8, 15}, "en", &result);

  EXPECT_FALSE(classified);
}
120
// Czech text must be classified as "translate" when the user's familiar
// languages ("de,en,ja") do not include Czech.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
  const UnicodeText czech_text =
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček.");

  ClassificationResult result;
  const bool classified = translate_annotator_.ClassifyText(
      czech_text, {8, 15}, "de,en,ja", &result);

  EXPECT_TRUE(classified);
  EXPECT_THAT(result,
              AllOf(Field(&ClassificationResult::collection, "translate")));
}
131
// English text must not be classified for translation when English is among
// the user's familiar languages ("cs,en,de,ja").
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english_text =
      UTF8ToUnicodeText("This is utterly unutterable.");

  ClassificationResult result;
  const bool classified = translate_annotator_.ClassifyText(
      english_text, {8, 15}, "cs,en,de,ja", &result);

  EXPECT_FALSE(classified);
}
139
// Checks the iterator returned by FindIndexOfNextWhitespaceOrPunctuation when
// scanning backwards (-1) and forwards (+1) from a given codepoint index.
TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
  const auto text_begin = text.begin();
  const auto text_end = text.end();

  // Scanning backwards from the first codepoint yields the beginning.
  EXPECT_EQ(
      translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 0, -1),
      text_begin);

  // Scanning forwards from inside the last word yields the end.
  EXPECT_EQ(
      translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 35, 1),
      text_end);

  // Index 10 lies inside "třicet": the space before it is at index 6 ...
  EXPECT_EQ(
      translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, -1),
      std::next(text_begin, 6));

  // ... and the comma after it is at index 13.
  EXPECT_EQ(
      translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, 1),
      std::next(text_begin, 13));
}
157
// Checks how TokenAlignedSubstringAroundSpan grows the span {35, 37} ("ač",
// inside the last word) as the requested minimum length increases.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");

  // All expectations on `text` use the same span; only the minimum varies.
  const auto substring_for = [&](int minimum_length) {
    return translate_annotator_.TokenAlignedSubstringAroundSpan(
        text, {35, 37}, minimum_length);
  };

  // A minimum far beyond the text length returns the whole text.
  EXPECT_EQ(substring_for(/*minimum_length=*/100), text);
  // A minimum of zero leaves the span untouched.
  EXPECT_EQ(substring_for(/*minimum_length=*/0), UTF8ToUnicodeText("ač"));
  // Small minimums expand to the enclosing word only.
  EXPECT_EQ(substring_for(/*minimum_length=*/3),
            UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(substring_for(/*minimum_length=*/10),
            UTF8ToUnicodeText("stříkaček"));
  // One more than that pulls in the preceding word as well.
  EXPECT_EQ(substring_for(/*minimum_length=*/11),
            UTF8ToUnicodeText("stříbrných stříkaček"));

  // With no whitespace or punctuation to stop at, the whole string is
  // returned.
  const UnicodeText text_no_whitespace =
      UTF8ToUnicodeText("reallyreallylongstring");
  EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
                text_no_whitespace, {10, 11}, /*minimum_length=*/2),
            UTF8ToUnicodeText("reallyreallylongstring"));
}
184
// Checks TokenAlignedSubstringAroundSpan on text consisting only of spaces.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
  // Twelve spaces. The text has to be at least seven codepoints long so that
  // the spans {5, 7} and {5, 5} used below lie inside it.
  const UnicodeText text = UTF8ToUnicodeText("            ");

  // Shouldn't modify the selection in case it's all whitespace: the
  // two-codepoint span comes back as exactly two spaces ...
  EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
                text, {5, 7}, /*minimum_length=*/3),
            UTF8ToUnicodeText("  "));
  // ... and the empty span comes back empty.
  EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
                text, {5, 5}, /*minimum_length=*/1),
            UTF8ToUnicodeText(""));
}
196
// Checks TokenAlignedSubstringAroundSpan when the span points into a long run
// of spaces between two single-letter tokens.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
  // "a", eight spaces, "a": ten codepoints in total, so the span {5, 7} lies
  // inside the whitespace run and the text is shorter than the minimum length
  // of 11 requested below.
  const UnicodeText text = UTF8ToUnicodeText("a        a");

  // Should still select the whole text even if pointing to whitespace
  // initially.
  EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
                text, {5, 7}, /*minimum_length=*/11),
            UTF8ToUnicodeText("a        a"));
}
206
207 } // namespace
208 } // namespace libtextclassifier3
209