1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator_test-include.h"
18 
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23 
24 #include "annotator/annotator.h"
25 #include "annotator/collections.h"
26 #include "annotator/model_generated.h"
27 #include "annotator/test-utils.h"
28 #include "annotator/types-test-util.h"
29 #include "annotator/types.h"
30 #include "utils/grammar/utils/locale-shard-map.h"
31 #include "utils/grammar/utils/rules.h"
32 #include "utils/testing/annotator.h"
33 #include "lang_id/fb_model/lang-id-from-fb.h"
34 #include "lang_id/lang-id.h"
35 
36 namespace libtextclassifier3 {
37 namespace test_internal {
38 
39 using ::testing::Contains;
40 using ::testing::ElementsAre;
41 using ::testing::ElementsAreArray;
42 using ::testing::Eq;
43 using ::testing::IsEmpty;
44 using ::testing::UnorderedElementsAreArray;
45 
GetTestModelPath()46 std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
47 
GetModelWithVocabPath()48 std::string GetModelWithVocabPath() {
49   return GetModelPath() + "test_vocab_model.fb";
50 }
51 
GetTestModelWithDatetimeRegEx()52 std::string GetTestModelWithDatetimeRegEx() {
53   std::string model_buffer = ReadFile(GetTestModelPath());
54   model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
55     model->datetime_grammar_model.reset(nullptr);
56   });
57   return model_buffer;
58 }
59 
ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan> & result,const std::string & currency,const std::string & amount,const int whole_part,const int decimal_part,const int nanos)60 void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
61                               const std::string& currency,
62                               const std::string& amount, const int whole_part,
63                               const int decimal_part, const int nanos) {
64   ASSERT_GT(result.size(), 0);
65   ASSERT_GT(result[0].classification.size(), 0);
66   ASSERT_EQ(result[0].classification[0].collection, "money");
67 
68   const EntityData* entity_data =
69       GetEntityData(result[0].classification[0].serialized_entity_data.data());
70   ASSERT_NE(entity_data, nullptr);
71   ASSERT_NE(entity_data->money(), nullptr);
72   EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
73   EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
74   EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
75   EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
76   EXPECT_EQ(entity_data->money()->nanos(), nanos);
77 }
78 
// Loading "wrong_embeddings.fb" is expected to fail, i.e. Annotator::FromPath
// must return a null pointer.
TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
  const std::string model_path = GetModelPath() + "wrong_embeddings.fb";
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  EXPECT_FALSE(annotator);
}
85 
VerifyClassifyText(const Annotator * classifier)86 void VerifyClassifyText(const Annotator* classifier) {
87   ASSERT_TRUE(classifier);
88 
89   EXPECT_EQ("other",
90             FirstResult(classifier->ClassifyText(
91                 "this afternoon Barack Obama gave a speech at", {15, 27})));
92   EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
93                          "Call me at (800) 123-456 today", {11, 24})));
94 
95   // More lines.
96   EXPECT_EQ("other",
97             FirstResult(classifier->ClassifyText(
98                 "this afternoon Barack Obama gave a speech at|Visit "
99                 "www.google.com every today!|Call me at (800) 123-456 today.",
100                 {15, 27})));
101   EXPECT_EQ("phone",
102             FirstResult(classifier->ClassifyText(
103                 "this afternoon Barack Obama gave a speech at|Visit "
104                 "www.google.com every today!|Call me at (800) 123-456 today.",
105                 {90, 103})));
106 
107   // Single word.
108   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
109   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
110 
111   // Junk. These should not crash the test.
112   classifier->ClassifyText("", {0, 0});
113   classifier->ClassifyText("asdf", {0, 0});
114   classifier->ClassifyText("asdf", {0, 27});
115   classifier->ClassifyText("asdf", {-30, 300});
116   classifier->ClassifyText("asdf", {-10, -1});
117   classifier->ClassifyText("asdf", {100, 17});
118   classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
119 
120   // Test invalid utf8 input.
121   EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
122                                      "\xf0\x9f\x98\x8b\x8b", {0, 0})));
123 }
124 
// Runs the shared ClassifyText checks against the model loaded from
// GetTestModelPath().
TEST_F(AnnotatorTest, ClassifyText) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyText(annotator.get());
}
130 
// Checks how the classification of "isotope" depends on the detected language
// tags passed in the classification options.
TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Classifies "isotope" with only the detected language tags set.
  const auto classify_with_language = [&annotator](const char* language_tags) {
    ClassificationOptions options;
    options.detected_text_language_tags = language_tags;
    return FirstResult(annotator->ClassifyText("isotope", {0, 7}, options));
  };

  // Default options.
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText("isotope", {0, 7})));

  EXPECT_EQ("dictionary", classify_with_language("en"));
  EXPECT_EQ("other", classify_with_language("uz"));
}
147 
// Requests the vocab annotator on the model from GetTestModelPath() and
// expects "isotope" to still be classified as "dictionary".
TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions options;
  options.use_vocab_annotator = true;
  options.detected_text_language_tags = "en";

  const auto classifications =
      annotator->ClassifyText("isotope", {0, 7}, options);
  EXPECT_EQ("dictionary", FirstResult(classifications));
}
160 
161 #ifdef TC3_VOCAB_ANNOTATOR_IMPL
// Toggles `use_vocab_annotator` on the model from GetModelWithVocabPath() and
// checks that the classification of "integrity" follows the flag.
TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions options;
  options.detected_text_language_tags = "en";

  // "integrity" is only labelled "dictionary" by the vocab annotator, not by
  // the FFModel, which makes it a usable probe for whether the vocab annotator
  // is active.
  options.use_vocab_annotator = true;
  const auto with_vocab = annotator->ClassifyText("integrity", {0, 9}, options);
  EXPECT_EQ("dictionary", FirstResult(with_vocab));

  options.use_vocab_annotator = false;
  const auto without_vocab =
      annotator->ClassifyText("integrity", {0, 9}, options);
  EXPECT_EQ("other", FirstResult(without_vocab));
}
180 #endif  // TC3_VOCAB_ANNOTATOR_IMPL
181 
// Removes the classification model while keeping selection enabled and
// expects annotator construction to fail.
TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(model != nullptr);

  // Drop the classification model and enable only the selection mode.
  model->classification_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const buffer =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer, fbb.GetSize(), unilib_.get(), calendarlib_.get());

  // Selection scores still depend on the classification model, so loading
  // must fail.
  ASSERT_FALSE(annotator);
}
201 
// Restricts the model's enabled modes to annotation and selection and expects
// ClassifyText to return no results.
TEST_F(AnnotatorTest, ClassifyTextDisabled) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const buffer =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(
      annotator->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
      IsEmpty());
}
220 
// Checks that a collection listed in `filtered_collections_classification` is
// no longer returned by ClassifyText, while other collections still are.
TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
  const std::string serialized_model = ReadFile(GetTestModelPath());

  // Baseline: the unmodified model classifies the number as a phone.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      serialized_model.c_str(), serialized_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  // Rebuild the model with "phone" filtered out of classification results.
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_classification.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  // Address classification is not filtered and must still work.
  EXPECT_EQ("address", FirstResult(annotator->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27})));
}
253 
TEST_F(AnnotatorTest,ClassifyTextRegularExpression)254 TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
255   const std::string test_model = ReadFile(GetTestModelPath());
256   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
257 
258   // Add test regex models.
259   unpacked_model->regex_model->patterns.push_back(MakePattern(
260       "person", "Barack Obama", /*enabled_for_classification=*/true,
261       /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
262   unpacked_model->regex_model->patterns.push_back(MakePattern(
263       "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
264       /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
265   std::unique_ptr<RegexModel_::PatternT> verified_pattern =
266       MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
267                   /*enabled_for_classification=*/true,
268                   /*enabled_for_selection=*/false,
269                   /*enabled_for_annotation=*/false, 1.0);
270   verified_pattern->verification_options.reset(new VerificationOptionsT);
271   verified_pattern->verification_options->verify_luhn_checksum = true;
272   unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
273 
274   flatbuffers::FlatBufferBuilder builder;
275   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
276 
277   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
278       reinterpret_cast<const char*>(builder.GetBufferPointer()),
279       builder.GetSize(), unilib_.get(), calendarlib_.get());
280   ASSERT_TRUE(classifier);
281 
282   EXPECT_EQ("flight",
283             FirstResult(classifier->ClassifyText(
284                 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
285   EXPECT_EQ("person",
286             FirstResult(classifier->ClassifyText(
287                 "this afternoon Barack Obama gave a speech at", {15, 27})));
288   EXPECT_EQ("email",
289             FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
290   EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
291                          "Contact me at you@android.com", {14, 29})));
292 
293   EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
294                        "Visit www.google.com every today!", {6, 20})));
295 
296   EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
297   EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
298                                                            {7, 12})));
299   EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
300                                 "cc: 4012 8888 8888 1881", {4, 23})));
301   EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
302                                 "2221 0067 4735 6281", {0, 19})));
303   // Luhn check fails.
304   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
305                                                           {0, 19})));
306 
307   // More lines.
308   EXPECT_EQ("url",
309             FirstResult(classifier->ClassifyText(
310                 "this afternoon Barack Obama gave a speech at|Visit "
311                 "www.google.com every today!|Call me at (800) 123-456 today.",
312                 {51, 65})));
313 }
314 
315 #ifndef TC3_DISABLE_LUA
// Attaches a Lua verifier to a regex pattern and checks that the pattern only
// classifies text for which the verifier script returns true.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Pattern whose verification is delegated to the Lua verifier at index 0.
  std::unique_ptr<RegexModel_::PatternT> parcel_pattern =
      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/false, 1.0);
  parcel_pattern->verification_options.reset(new VerificationOptionsT);
  parcel_pattern->verification_options->lua_verifier = 0;

  RegexModelT* const regex_model = model->regex_model.get();
  regex_model->patterns.push_back(std::move(parcel_pattern));
  // The script accepts a match only if the second capturing group is "99".
  regex_model->lua_verifier.push_back("return match[2].text==\"99\"");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // The pattern matches and the verifier accepts it.
  EXPECT_EQ("parcel_tracking", FirstResult(annotator->ClassifyText(
                                   "99-00-123456-12345678", {0, 21})));

  // The pattern matches but the verifier rejects it.
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(
                         "90-00-123456-12345678", {0, 21})));
}
348 #endif  // TC3_DISABLE_LUA
349 
TEST_F(AnnotatorTest,ClassifyTextRegularExpressionEntityData)350 TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
351   const std::string test_model = ReadFile(GetTestModelPath());
352   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
353 
354   // Add fake entity schema metadata.
355   AddTestEntitySchemaData(unpacked_model.get());
356 
357   AddTestRegexModel(unpacked_model.get());
358 
359   flatbuffers::FlatBufferBuilder builder;
360   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
361 
362   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
363       reinterpret_cast<const char*>(builder.GetBufferPointer()),
364       builder.GetSize(), unilib_.get(), calendarlib_.get());
365   ASSERT_TRUE(classifier);
366 
367   // Check with full name.
368   {
369     auto classifications =
370         classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
371     EXPECT_EQ(1, classifications.size());
372     EXPECT_EQ("person_with_age", classifications[0].collection);
373 
374     // Check entity data.
375     const flatbuffers::Table* entity =
376         flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
377             classifications[0].serialized_entity_data.data()));
378     EXPECT_EQ(
379         entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
380         "Barack");
381     EXPECT_EQ(
382         entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
383         "Obama");
384     // Check `age`.
385     EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
386 
387     // Check `is_alive`.
388     EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
389 
390     // Check `former_us_president`.
391     EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
392   }
393 
394   // Check only with first name.
395   {
396     auto classifications =
397         classifier->ClassifyText("Barack is 57 years old", {0, 22});
398     EXPECT_EQ(1, classifications.size());
399     EXPECT_EQ("person_with_age", classifications[0].collection);
400 
401     // Check entity data.
402     const flatbuffers::Table* entity =
403         flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
404             classifications[0].serialized_entity_data.data()));
405     EXPECT_EQ(
406         entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
407         "Barack");
408 
409     // Check `age`.
410     EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
411 
412     // Check `is_alive`.
413     EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
414 
415     // Check `former_us_president`.
416     EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
417   }
418 }
419 
// Checks that codepoint-wise normalization configured on a capturing group is
// applied to the value written into the serialized entity data.
//
// The size check is fatal (ASSERT_EQ) because `classifications[0]` is indexed
// right after it; a non-fatal check would let the test read past the end of
// an empty result vector.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing. This modifies capturing group 2
  // of the last pattern in the regex model.
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  auto classifications =
      classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
  ASSERT_EQ(1, classifications.size());
  EXPECT_EQ("person_with_age", classifications[0].collection);

  // Check entity data normalization. An empty buffer must not be parsed as a
  // flatbuffer.
  ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classifications[0].serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
            "OBAMA");
}
458 
// Registers two identical regex patterns that differ only in collection name
// and priority score, and checks that ClassifyText returns the collection of
// the pattern with the higher priority score.
TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));

  // Replace all regex patterns with the two competing test patterns.
  auto& patterns = model->regex_model->patterns;
  patterns.clear();
  patterns.push_back(MakePattern(
      "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/1.0));
  patterns.push_back(MakePattern(
      "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/0.0));

  // "flight1" has the higher priority score (1.0 vs 0.0).
  {
    flatbuffers::FlatBufferBuilder fbb;
    FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
    std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
        unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(annotator);

    const auto classifications = annotator->ClassifyText(
        "Your flight LX373 is delayed by 3 hours.", {12, 17});
    EXPECT_EQ("flight1", FirstResult(classifications));
  }

  // After raising the second pattern's priority to 3.0, "flight2" wins.
  patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder fbb;
    FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
    std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
        unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(annotator);

    const auto classifications = annotator->ClassifyText(
        "Your flight LX373 is delayed by 3 hours.", {12, 17});
    EXPECT_EQ("flight2", FirstResult(classifications));
  }
}
501 
// Registers two identical "flight" regex patterns that differ only in
// priority score and checks that the annotation returned by Annotate carries
// the higher of the two priority scores.
//
// Both emptiness checks are fatal (ASSERT_THAT) because the test indexes
// `results[0].classification[0]` right after them; a non-fatal check would
// let the test read past the end of an empty vector.
TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
  // Add test regex models. One of them has a higher priority score than
  // the other. We'll test that the one with the higher priority score
  // always ends up winning.
  unpacked_model->regex_model->patterns.clear();
  const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/0.0));

  // "flight" that wins should have a priority score of 1.0.
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, Not(IsEmpty()));
    ASSERT_THAT(results[0].classification, Not(IsEmpty()));
    // Lower bound slightly below 1.0 to tolerate floating point rounding.
    EXPECT_GE(results[0].classification[0].priority_score, 0.9);
  }

  // When we increase the priority score, the "flight" that wins should have a
  // priority score of 3.0.
  unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, Not(IsEmpty()));
    ASSERT_THAT(results[0].classification, Not(IsEmpty()));
    // Lower bound slightly below 3.0 to tolerate floating point rounding.
    EXPECT_GE(results[0].classification[0].priority_score, 2.9);
  }
}
554 
// Adds selection-only regex patterns (person, flight, Luhn-verified payment
// card) and checks that SuggestSelection expands a click to the regex match.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Register the test patterns; all are enabled for selection only.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;

  // The payment card pattern additionally requires a valid Luhn checksum.
  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Each click span should be expanded to the full regex match.
  EXPECT_EQ(annotator->SuggestSelection(
                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
            CodepointSpan(12, 19));
  EXPECT_EQ(annotator->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 27));
  EXPECT_EQ(annotator->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
            CodepointSpan(4, 23));
}
594 
// Configures per-capturing-group `extend_selection` flags on a date range
// pattern and checks the selection spans SuggestSelection returns for it.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Selection-only pattern with three explicit capturing groups.
  std::unique_ptr<RegexModel_::PatternT> date_range_pattern =
      MakePattern("date_range",
                  "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
                  "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);

  // Append four capturing group configs; the group at index 0 does not extend
  // the selection, the groups at indices 1-3 do.
  constexpr int kNumGroups = 4;
  auto& groups = date_range_pattern->capturing_group;
  for (int i = 0; i < kNumGroups; ++i) {
    groups.emplace_back(new CapturingGroupT);
  }
  for (int i = 0; i < kNumGroups; ++i) {
    groups[i]->extend_selection = (i != 0);
  }
  model->regex_model->patterns.push_back(std::move(date_range_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // The leading "from " is not part of the expected selection.
  EXPECT_EQ(annotator->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
                                        {21, 23}),
            CodepointSpan(10, 34));
  EXPECT_EQ(annotator->SuggestSelection("it takes for ever", {9, 12}),
            CodepointSpan(9, 17));
}
637 
// Gives the "flight" regex a priority score of 0.5 and checks the selection
// returned for a click on "MA" inside an address. Per the test name, the
// model's selection is expected to win over the regex match here.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Register the selection-only test patterns.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Conflict resolution: the expected selection is the span (26, 62).
  const CodepointSpan suggested = annotator->SuggestSelection(
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
      {55, 57});
  EXPECT_EQ(suggested, CodepointSpan(26, 62));
}
666 
// Same setup as the "ModelWins" variant, but the "flight" regex pattern is
// given priority_score 1.1. Clicking on "MA" is then expected to yield the
// span (55, 62), i.e. "MA 0123".
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Register the regex patterns used to provoke the conflict.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Only the most recently added ("flight") pattern gets the raised priority.
  patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Check conflict resolution.
  const std::string text =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  const CodepointSpan suggested = annotator->SuggestSelection(text, {55, 57});
  EXPECT_EQ(suggested, CodepointSpan(55, 62));
}
695 
// Adds annotation-enabled regex patterns ("person", "flight", and a
// "payment_card" pattern with verify_luhn_checksum set) to the test model and
// checks the ordered annotations produced for a multi-line text.
TEST_F(AnnotatorTest, AnnotateRegex) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());
  auto& patterns = model->regex_model->patterns;

  // Plain patterns, enabled for annotation only.
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  // Payment card pattern with Luhn checksum verification switched on.
  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/true, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
  const auto annotations = annotator->Annotate(text);
  EXPECT_THAT(annotations,
              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
                                IsAnnotatedSpan(28, 55, "address"),
                                IsAnnotatedSpan(79, 91, "phone"),
                                IsAnnotatedSpan(107, 126, "payment_card")}));
}
732 
// Checks which of several flight-number-like tokens the unmodified test model
// annotates as "flight".
TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // ICAO codes are only accepted for selected airlines, so SWR373 is not
  // expected in the output; LX373, EZY1234 and U21234 are.
  const std::string text = "flights LX373, SWR373, EZY1234, U21234";
  const auto annotations = annotator->Annotate(text);
  EXPECT_THAT(annotations,
              ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
                                IsAnnotatedSpan(23, 30, "flight"),
                                IsAnnotatedSpan(32, 38, "flight")}));
}
746 
747 #ifndef TC3_DISABLE_LUA
// Adds a "parcel_tracking" regex pattern whose verification options refer to
// a Lua verifier snippet, and checks that a matching tracking number is
// annotated.
TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Pattern enabled for classification, selection and annotation.
  std::unique_ptr<RegexModel_::PatternT> parcel_pattern =
      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/true, 1.0);
  parcel_pattern->verification_options.reset(new VerificationOptionsT);
  // NOTE(review): 0 presumably indexes into regex_model->lua_verifier, i.e.
  // the snippet pushed below -- confirm against the model schema.
  parcel_pattern->verification_options->lua_verifier = 0;
  model->regex_model->patterns.push_back(std::move(parcel_pattern));
  model->regex_model->lua_verifier.push_back(
      "return match[2].text==\"99\"");

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text = "your parcel is on the way: 99-00-123456-12345678";
  const auto annotations = annotator->Annotate(text);
  EXPECT_THAT(annotations,
              ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
}
777 #endif  // TC3_DISABLE_LUA
778 
// Annotates a sentence with the test regex model while serialized entity data
// is enabled, then reads individual fields back out of the resulting
// flatbuffer.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // Everything below indexes element 0 of both containers, so a size mismatch
  // must abort the test instead of running into out-of-bounds accesses.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check entity data. The buffer must be non-empty before it can be
  // interpreted as a flatbuffer.
  const auto& serialized =
      annotations[0].classification[0].serialized_entity_data;
  ASSERT_FALSE(serialized.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(serialized.data()));
  ASSERT_TRUE(entity != nullptr);

  // String fields at vtable offsets 4 and 8 hold the two parts of the name.
  // Guard the pointers so that an absent field fails the test cleanly.
  const flatbuffers::String* name_at_4 =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(name_at_4 != nullptr);
  EXPECT_EQ(name_at_4->str(), "Barack");
  const flatbuffers::String* name_at_8 =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(name_at_8 != nullptr);
  EXPECT_EQ(name_at_8->str(), "Obama");

  // Check `age`.
  EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

  // Check `is_alive`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

  // Check `former_us_president`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
}
821 
// Enables uppercase codepoint-wise normalization on capturing group 2 of the
// last test regex pattern and checks that the corresponding entity data field
// comes back upper-cased.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing.
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // Everything below indexes element 0 of both containers, so a size mismatch
  // must abort the test instead of running into out-of-bounds accesses.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check normalization. The buffer must be non-empty before it can be
  // interpreted as a flatbuffer, and the string field must be present before
  // it is dereferenced.
  const auto& serialized =
      annotations[0].classification[0].serialized_entity_data;
  ASSERT_FALSE(serialized.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(serialized.data()));
  ASSERT_TRUE(entity != nullptr);
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(last_name != nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
863 
// With is_serialized_entity_data_enabled set to false, the annotation is
// still produced but its serialized entity data is expected to be empty.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = false;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // Everything below indexes element 0 of both containers, so a size mismatch
  // must abort the test instead of running into out-of-bounds accesses.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check entity data.
  EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data);
}
892 
// Classifies phone-like selections of increasing length; the longest
// selection is expected to be classified as "other" instead of "phone".
TEST_F(AnnotatorTest, PhoneFiltering) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string short_text = "phone: (123) 456 789";
  const std::string long_text = "phone: (123) 456 789,0001112";

  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(short_text, {7, 20})));
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(long_text, {7, 25})));
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(long_text, {7, 28})));
}
905 
// Exercises SuggestSelection on the unmodified test model: plain words, phone
// numbers, bracket stripping, and cases where the input span is returned
// unchanged.
TEST_F(AnnotatorTest, SuggestSelection) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a SuggestSelection call with default options.
  const auto suggest = [&annotator](const std::string& text,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(text, click);
  };
  const std::string phone_text = "call me at 857 225 3556 today";

  EXPECT_EQ(suggest("this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 21));

  // Passing the whole string: when more than one token is selected, the input
  // span is expected back unchanged.
  EXPECT_EQ(suggest("350 Third Street, Cambridge", {0, 27}),
            CodepointSpan(0, 27));

  // Single letter.
  EXPECT_EQ(suggest("a", {0, 1}), CodepointSpan(0, 1));

  // Single word.
  EXPECT_EQ(suggest("asdf", {0, 4}), CodepointSpan(0, 4));

  EXPECT_EQ(suggest(phone_text, {11, 14}), CodepointSpan(11, 23));

  // Unpaired bracket stripping.
  EXPECT_EQ(suggest("call me at (857) 225 3556 today", {12, 14}),
            CodepointSpan(11, 25));
  EXPECT_EQ(suggest("call me at (857 today", {12, 14}), CodepointSpan(12, 15));
  EXPECT_EQ(suggest("call me at 3556) today", {12, 14}),
            CodepointSpan(11, 15));
  EXPECT_EQ(suggest("call me at )857( today", {12, 14}),
            CodepointSpan(12, 15));

  // If the resulting selection would be empty, the original span is returned.
  EXPECT_EQ(suggest("call me at )( today", {11, 13}), CodepointSpan(11, 13));
  EXPECT_EQ(suggest("call me at ( today", {11, 12}), CodepointSpan(11, 12));
  EXPECT_EQ(suggest("call me at ) today", {11, 12}), CodepointSpan(11, 12));

  // If the original span is larger than the found selection, the original span
  // is returned.
  EXPECT_EQ(suggest(phone_text, {5, 24}), CodepointSpan(5, 24));
}
956 
// Clears the selection model while leaving only ANNOTATION enabled in the
// triggering options; constructing an annotator from that model is expected
// to fail.
TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Disable the selection model.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  // Annotation requires the selection model, so loading must not succeed.
  ASSERT_FALSE(annotator);
}
975 
// Clears the selection model and restricts the enabled modes to
// CLASSIFICATION. Selection is then expected to echo the input span,
// classification to keep working, and annotation to return nothing.
TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Disable the selection model.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
  model->enabled_modes = ModeFlag_CLASSIFICATION;

  // The number annotator needs a feature processor, which does not exist once
  // the selection model is disabled, so it has to be switched off as well.
  model->number_annotator_options->enabled = false;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const CodepointSpan suggested =
      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14});
  EXPECT_EQ(suggested, CodepointSpan(11, 14));

  const std::string toll_free_text = "call me at (800) 123-456 today";
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(toll_free_text, {11, 24})));

  EXPECT_THAT(annotator->Annotate(toll_free_text), IsEmpty());
}
1008 
// Compares phone selection on the stock model with a model that lists "phone"
// in filtered_collections_selection; address selection is expected to keep
// working on the filtered model.
TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  const std::string phone_text = "call me at 857 225 3556 today";

  // Baseline: the stock model expands the click to the whole phone number.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_buffer.c_str(), model_buffer.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 23));

  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());
  model->output_options.reset(new OutputOptionsT);

  // Disable phone selection
  model->output_options->filtered_collections_selection.push_back("phone");
  // We need to force this for filtering.
  model->selection_options->always_classify_suggested_selection = true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
  const char* packed =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  annotator = Annotator::FromUnownedBuffer(
      packed, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // With "phone" filtered, the click span comes back unchanged.
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 14));

  // Address selection should still work.
  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
            CodepointSpan(0, 27));
}
1045 
// Clicking on different tokens of the same address is expected to produce the
// same suggested span, including when the address follows other lines.
TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address = "350 Third Street, Cambridge";
  const CodepointSpan whole_address(0, 27);

  EXPECT_EQ(annotator->SuggestSelection(address, {0, 3}), whole_address);
  EXPECT_EQ(annotator->SuggestSelection(address, {4, 9}), whole_address);
  EXPECT_EQ(annotator->SuggestSelection(address, {10, 16}), whole_address);

  // Same address preceded by three short lines.
  const std::string multiline = "a\nb\nc\n350 Third Street, Cambridge";
  EXPECT_EQ(annotator->SuggestSelection(multiline, {16, 22}),
            CodepointSpan(6, 33));
}
1062 
// Checks phone number selection when newlines appear before, after, or
// inside the number.
TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Newline before, then after, the number.
  const CodepointSpan after_newline =
      annotator->SuggestSelection("abc\n857 225 3556", {4, 7});
  EXPECT_EQ(after_newline, CodepointSpan(4, 16));
  const CodepointSpan before_newline =
      annotator->SuggestSelection("857 225 3556\nabc", {0, 3});
  EXPECT_EQ(before_newline, CodepointSpan(0, 12));

  // Newline inside the number, with explicitly passed default options.
  SelectionOptions options;
  const CodepointSpan across_newline =
      annotator->SuggestSelection("857 225\n3556\nabc", {0, 3}, options);
  EXPECT_EQ(across_newline, CodepointSpan(0, 12));
}
1077 
// A selected word with punctuation attached on either side is expected to
// come back with the same span as was passed in.
TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a SuggestSelection call with default options.
  const auto suggest = [&annotator](const std::string& text,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(text, click);
  };

  // From the right.
  EXPECT_EQ(suggest("this afternoon BarackObama, gave a speech at", {15, 26}),
            CodepointSpan(15, 26));

  // From the right multiple.
  EXPECT_EQ(
      suggest("this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
      CodepointSpan(15, 26));

  // From the left multiple.
  EXPECT_EQ(
      suggest("this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
      CodepointSpan(21, 32));

  // From both sides.
  EXPECT_EQ(
      suggest("this afternoon !BarackObama,- gave a speech at", {16, 27}),
      CodepointSpan(16, 27));
}
1103 
// Feeds out-of-range, negative, and reversed spans, as well as invalid UTF-8,
// into SuggestSelection. Each call is expected to return the span it was
// given.
TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a SuggestSelection call with default options.
  const auto suggest = [&annotator](const std::string& text,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(text, click);
  };
  const std::string words = "Word 1 2 3 hello!";

  // Try passing in bunch of invalid selections.
  EXPECT_EQ(suggest("", {0, 27}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest("", {-10, 27}), CodepointSpan(-10, 27));
  EXPECT_EQ(suggest(words, {0, 27}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest(words, {-30, 300}), CodepointSpan(-30, 300));
  EXPECT_EQ(suggest(words, {-10, -1}), CodepointSpan(-10, -1));
  EXPECT_EQ(suggest(words, {100, 17}), CodepointSpan(100, 17));

  // Try passing invalid utf8.
  EXPECT_EQ(suggest("\xf0\x9f\x98\x8b\x8b", {-1, -1}), CodepointSpan(-1, -1));
}
1126 
// Checks SuggestSelection when the click span covers whitespace (or a
// separator) inside, next to, or away from an entity.
TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a SuggestSelection call with default options.
  const auto suggest = [&annotator](const std::string& text,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(text, click);
  };
  const std::string phone_text = "call me at 857 225 3556 today";
  const std::string phone_comma_text = "call me at 857 225 3556, today";
  const std::string no_entity_text = "call me today";
  const std::string blanks = "      ";

  EXPECT_EQ(suggest(phone_text, {14, 15}), CodepointSpan(11, 23));
  EXPECT_EQ(suggest(phone_text, {10, 11}), CodepointSpan(10, 11));
  EXPECT_EQ(suggest(phone_text, {23, 24}), CodepointSpan(23, 24));
  EXPECT_EQ(suggest(phone_comma_text, {23, 24}), CodepointSpan(23, 24));
  EXPECT_EQ(suggest("call me at 857   225 3556, today", {14, 17}),
            CodepointSpan(11, 25));
  EXPECT_EQ(suggest("call me at 857-225 3556, today", {14, 17}),
            CodepointSpan(11, 23));
  EXPECT_EQ(suggest("let's meet at 350 Third Street Cambridge and go there",
                    {30, 31}),
            CodepointSpan(14, 40));
  EXPECT_EQ(suggest(no_entity_text, {4, 5}), CodepointSpan(4, 5));
  EXPECT_EQ(suggest(no_entity_text, {7, 8}), CodepointSpan(7, 8));

  // With a punctuation around the selected whitespace.
  EXPECT_EQ(suggest("let's meet at 350 Third Street, Cambridge and go there",
                    {31, 32}),
            CodepointSpan(14, 41));

  // When all's whitespace, should return the original indices.
  EXPECT_EQ(suggest(blanks, {0, 1}), CodepointSpan(0, 1));
  EXPECT_EQ(suggest(blanks, {0, 3}), CodepointSpan(0, 3));
  EXPECT_EQ(suggest(blanks, {2, 3}), CodepointSpan(2, 3));
  EXPECT_EQ(suggest(blanks, {5, 6}), CodepointSpan(5, 6));
}
1175 
// Calls internal::SnapLeftIfWhitespaceSelection directly on whitespace spans:
// with text to the left the span is expected to move onto the preceding
// character, otherwise it is expected back unchanged.
TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
  UnicodeText context;

  // Text to the left of the selected whitespace.
  context = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(3, 4),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, context, *unilib_));
  context = UTF8ToUnicodeText("abcd     ", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(3, 4),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, context, *unilib_));

  // Nothing on the left.
  context = UTF8ToUnicodeText("     efgh", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(4, 5),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, context, *unilib_));
  context = UTF8ToUnicodeText("     efgh", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(0, 1),
            internal::SnapLeftIfWhitespaceSelection({0, 1}, context, *unilib_));

  // Whitespace only.
  context = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(2, 3),
            internal::SnapLeftIfWhitespaceSelection({2, 3}, context, *unilib_));
  context = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(4, 5),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, context, *unilib_));
  context = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(CodepointSpan(0, 1),
            internal::SnapLeftIfWhitespaceSelection({0, 1}, context, *unilib_));
}
1205 
// Basic annotation checks on the unmodified test model: a multi-line text,
// a phone number with and without an embedded newline, and invalid UTF-8.
TEST_F(AnnotatorTest, Annotate) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  const auto annotations = annotator->Annotate(text);
  EXPECT_THAT(annotations, ElementsAreArray({
                               IsAnnotatedSpan(28, 55, "address"),
                               IsAnnotatedSpan(79, 91, "phone"),
                           }));

  AnnotationOptions options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));

  // Invalid UTF-8 in the input is expected to yield no annotations at all.
  EXPECT_THAT(
      annotator->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options),
      IsEmpty());
}
1230 
// Phone annotations are expected to keep paired brackets but exclude unpaired
// ones from the annotated span.
TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Paired brackets stay inside the span.
  const auto paired = annotator->Annotate("call me at (0845) 100 1000 today");
  EXPECT_THAT(paired, ElementsAreArray({IsAnnotatedSpan(11, 26, "phone")}));

  // Unpaired bracket stripping.
  const auto open_only = annotator->Annotate("call me at (07038201818 today");
  EXPECT_THAT(open_only, ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));

  const auto close_only = annotator->Annotate("call me at 07038201818) today");
  EXPECT_THAT(close_only,
              ElementsAreArray({IsAnnotatedSpan(11, 22, "phone")}));

  const auto reversed = annotator->Annotate("call me at )07038201818( today");
  EXPECT_THAT(reversed, ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
}
1255 
// Same bracket-stripping expectations as the non-optimized test, but with
// AnnotationOptions::enable_optimization switched on.
TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions options;
  options.enable_optimization = true;

  // Paired brackets stay inside the span.
  const auto paired =
      annotator->Annotate("call me at (0845) 100 1000 today", options);
  EXPECT_THAT(paired, ElementsAreArray({IsAnnotatedSpan(11, 26, "phone")}));

  // Unpaired bracket stripping.
  const auto open_only =
      annotator->Annotate("call me at (07038201818 today", options);
  EXPECT_THAT(open_only, ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));

  const auto close_only =
      annotator->Annotate("call me at 07038201818) today", options);
  EXPECT_THAT(close_only,
              ElementsAreArray({IsAnnotatedSpan(11, 22, "phone")}));

  const auto reversed =
      annotator->Annotate("call me at )07038201818( today", options);
  EXPECT_THAT(reversed, ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
}
1283 
// In the RAW annotation usecase, overlapping annotations (e.g. "phone" and
// "number" over the same digits) are all expected in the output, in any
// order.
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions raw_options;
  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // Number, float number and percentage annotator.
  const std::string text =
      "853 225 3556 and then turn it up 99%, 99 "
      "number, 12345.12345 float number";
  const auto annotations = annotator->Annotate(text, raw_options);
  EXPECT_THAT(
      annotations,
      UnorderedElementsAreArray(
          {IsAnnotatedSpan(0, 12, "phone"), IsAnnotatedSpan(0, 3, "number"),
           IsAnnotatedSpan(4, 7, "number"), IsAnnotatedSpan(8, 12, "number"),
           IsAnnotatedSpan(33, 35, "number"),
           IsAnnotatedSpan(33, 36, "percentage"),
           IsAnnotatedSpan(38, 40, "number"), IsAnnotatedSpan(49, 60, "number"),
           IsAnnotatedSpan(49, 60, "phone")}));
}
1304 
// In the SMART annotation usecase only the phone number and the percentage
// are expected; no plain "number" annotations appear in the expected output.
TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  const std::string text = "853 225 3556 and then turn it up 99%, 99 number";
  const auto annotations = annotator->Annotate(text, smart_options);
  EXPECT_THAT(annotations,
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"),
                                IsAnnotatedSpan(33, 36, "percentage")}));
}
1317 
VerifyAnnotatesDurationsInRawMode(const Annotator * classifier)1318 void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) {
1319   ASSERT_TRUE(classifier);
1320   AnnotationOptions options;
1321   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1322 
1323   // Duration annotator.
1324   EXPECT_THAT(classifier->Annotate(
1325                   "it took 9 minutes and 7 seconds to get there", options),
1326               Contains(IsDurationSpan(
1327                   /*start=*/8, /*end=*/31,
1328                   /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
1329 }
1330 
// Runs the raw-mode duration check against the test model loaded from its
// file path.
TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotatesDurationsInRawMode(annotator.get());
}
1336 
VerifyDurationAndRelativeTimeCanOverlapInRawMode(const Annotator * classifier)1337 void VerifyDurationAndRelativeTimeCanOverlapInRawMode(
1338     const Annotator* classifier) {
1339   ASSERT_TRUE(classifier);
1340   AnnotationOptions options;
1341   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1342   options.locales = "en";
1343 
1344   const std::vector<AnnotatedSpan> annotations =
1345       classifier->Annotate("let's meet in 3 hours", options);
1346 
1347   EXPECT_THAT(annotations,
1348               Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
1349                                       /*time_ms_utc=*/10800000L,
1350                                       DatetimeGranularity::GRANULARITY_HOUR)));
1351   EXPECT_THAT(annotations,
1352               Contains(IsDurationSpan(/*start=*/14, /*end=*/21,
1353                                       /*duration_ms=*/3 * 60 * 60 * 1000)));
1354 }
1355 
// Runs the duration/relative-time overlap check against the test model loaded
// from its file path.
TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1361 
// Same overlap check, but against the model variant produced by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it must outlive it.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1370 
// With `only_use_line_with_click` enabled, annotating two lines joined by '\n'
// must give the same address annotation as annotating the second line alone,
// shifted by the length of the first line plus the newline.
TEST_F(AnnotatorTest, AnnotateSplitLines) {
  const std::string patched_model = ModifyAnnotatorModel(
      ReadFile(GetTestModelPath()), [](ModelT* model) {
        model->selection_feature_options->only_use_line_with_click = true;
      });
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string chat_line =
      "hey, sorry, just finished up. i didn't hear back from you in time.";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  // Length in codepoints of the address annotation within `address_line`.
  const int kAnnotationLength = 26;

  // Each line on its own.
  EXPECT_THAT(annotator->Annotate(chat_line), IsEmpty());
  EXPECT_THAT(
      annotator->Annotate(address_line),
      ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")}));

  // Both lines together: only the address, offset past the first line.
  const std::string both_lines = chat_line + "\n" + address_line;
  const auto address_begin = chat_line.size() + 1;
  EXPECT_THAT(annotator->Annotate(both_lines),
              ElementsAreArray({IsAnnotatedSpan(
                  address_begin, address_begin + kAnnotationLength,
                  "address")}));
}
1398 
// With `use_pipe_character_for_newline` enabled, '|' acts as a line separator:
// annotating "<phone line>|<address line>" finds the phone and the address
// spans, as if each line had been annotated individually.
TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
  const std::string patched_model = ModifyAnnotatorModel(
      ReadFile(GetTestModelPath()), [](ModelT* model) {
        model->selection_feature_options->only_use_line_with_click = true;
        model->selection_feature_options->use_pipe_character_for_newline = true;
      });
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_line = "hey, this is my phone number 853 225 3556";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string piped = phone_line + "|" + address_line;

  const int kAnnotationLengthPhone = 12;
  const int kAnnotationLengthAddress = 26;
  // The phone number starts at codepoint 29 of the first line; the address
  // starts right after the '|' separator.
  const int phone_begin = 29;
  const int address_begin = static_cast<int>(phone_line.size()) + 1;

  const std::vector<AnnotatedSpan> spans = annotator->Annotate(piped);
  EXPECT_THAT(
      spans,
      ElementsAreArray({
          IsAnnotatedSpan(phone_begin, phone_begin + kAnnotationLengthPhone,
                          "phone"),
          IsAnnotatedSpan(address_begin,
                          address_begin + kAnnotationLengthAddress, "address"),
      }));
}
1429 
// With `use_pipe_character_for_newline` explicitly disabled, '|' is not a line
// separator, so "<phone line>|<address line>" is annotated as a single line.
TEST_F(AnnotatorTest,
       NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
    model->selection_feature_options->only_use_line_with_click = true;
    model->selection_feature_options->use_pipe_character_for_newline = false;
  });
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());

  ASSERT_TRUE(classifier);

  const std::string str1 = "hey, this is my phone number 853 225 3556";
  const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string str3 = str1 + "|" + str2;
  // Held by value: Annotate() returns a temporary, so there is nothing to
  // reference.
  const std::vector<AnnotatedSpan> annotated_spans = classifier->Annotate(str3);
  // Note: We only check that we get a single annotated span here when the '|'
  // character is not used to split lines. The reason behind this is that the
  // model is not precise for such an example and the resulting annotated span
  // might change when the model changes.
  EXPECT_EQ(annotated_spans.size(), 1);
}
1454 
// Annotation results must stay the same when the selection model is run with a
// small batch size (4).
TEST_F(AnnotatorTest, AnnotateSmallBatches) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Set the batch size, then re-serialize the model.
  model->selection_options->batch_size = 4;
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Multi-line text with an address on the first line and a phone number on
  // the second.
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // A phone number on one line, and the same number split by a newline.
  const AnnotationOptions default_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
}
1484 
// A `min_annotate_confidence` of 2.0 must filter out every annotation.
TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Replace the triggering options with a threshold that discards all
  // results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 2.f;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_EQ(annotator->Annotate(text).size(), 0);
}
1507 
// A `min_annotate_confidence` of 0.0 (with all modes enabled) must keep both
// annotations of the sample text.
TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Replace the triggering options with a threshold that keeps all results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 0.f;
  model->triggering_options->enabled_modes = ModeFlag_ALL;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_EQ(annotator->Annotate(text).size(), 2);
}
1530 
// When the model's enabled modes cover only classification and selection,
// Annotate() must return nothing.
TEST_F(AnnotatorTest, AnnotateDisabled) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Disable the model for annotation.
  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
1549 
// Collections listed in `output_options.filtered_collections_annotation` must
// be dropped from Annotate() results: first check the baseline (address and
// phone), then filter "phone" and expect only the address.
TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
  const std::string original_model = ReadFile(GetTestModelPath());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline with the unmodified model.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      original_model.c_str(), original_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // Disable phone annotation.
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_annotation.push_back("phone");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1587 
// A filtered collection still takes part in conflict resolution: a custom
// "suppress" regex annotator overlaps the phone number, and since "suppress"
// is filtered from the output, neither it nor the phone span is returned.
TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string original_model = ReadFile(GetTestModelPath());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline with the unmodified model.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      original_model.c_str(), original_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // We add a custom annotator that wins against the phone classification
  // and that we subsequently suppress via the output filter.
  model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_annotation.push_back("suppress");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1631 
VerifyClassifyTextDateInZurichTimezone(const Annotator * classifier)1632 void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
1633   EXPECT_TRUE(classifier);
1634   ClassificationOptions options;
1635   options.reference_timezone = "Europe/Zurich";
1636   options.locales = "en";
1637 
1638   std::vector<ClassificationResult> result =
1639       classifier->ClassifyText("january 1, 2017", {0, 15}, options);
1640 
1641   EXPECT_THAT(result,
1642               ElementsAre(IsDateResult(1483225200000,
1643                                        DatetimeGranularity::GRANULARITY_DAY)));
1644 }
1645 
// Runs the Zurich-timezone date classification check against the test model
// loaded from its file path.
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1651 
// Same Zurich-timezone check, against the model variant produced by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it must outlive it.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1659 
VerifyClassifyTextDateInLATimezone(const Annotator * classifier)1660 void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
1661   EXPECT_TRUE(classifier);
1662   ClassificationOptions options;
1663   options.reference_timezone = "America/Los_Angeles";
1664   options.locales = "en";
1665 
1666   std::vector<ClassificationResult> result =
1667       classifier->ClassifyText("march 1, 2017", {0, 13}, options);
1668 
1669   EXPECT_THAT(result,
1670               ElementsAre(IsDateResult(1488355200000,
1671                                        DatetimeGranularity::GRANULARITY_DAY)));
1672 }
1673 
// Los Angeles-timezone date check, against the model variant produced by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it must outlive it.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1681 
// Runs the Los Angeles-timezone date classification check against the test
// model loaded from its file path.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1687 
VerifyClassifyTextDateOnAotherLine(const Annotator * classifier)1688 void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
1689   EXPECT_TRUE(classifier);
1690   ClassificationOptions options;
1691   options.reference_timezone = "Europe/Zurich";
1692   options.locales = "en";
1693 
1694   std::vector<ClassificationResult> result = classifier->ClassifyText(
1695       "hello world this is the first line\n"
1696       "january 1, 2017",
1697       {35, 50}, options);
1698 
1699   EXPECT_THAT(result,
1700               ElementsAre(IsDateResult(1483225200000,
1701                                        DatetimeGranularity::GRANULARITY_DAY)));
1702 }
1703 
// Date-on-second-line check, against the model variant produced by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it must outlive it.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1711 
// Runs the date-on-second-line classification check against the test model
// loaded from its file path.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1717 
VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(const Annotator * classifier)1718 void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
1719     const Annotator* classifier) {
1720   EXPECT_TRUE(classifier);
1721   std::vector<ClassificationResult> result;
1722   ClassificationOptions options;
1723 
1724   options.reference_timezone = "Europe/Zurich";
1725   options.locales = "en-US";
1726   result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
1727 
1728   // In US, the date should be interpreted as <month>.<day>.
1729   EXPECT_THAT(result,
1730               ElementsAre(IsDatetimeResult(
1731                   5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
1732 }
1733 
// Runs the en-US <month>.<day> parsing check against the test model loaded
// from its file path.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1739 
// en-US <month>.<day> parsing check, against the model variant produced by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it must outlive it.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1748 
// With the "de" locale, "03.05.1970" is read as <day>.<month>, so
// "03.05.1970 00:00vorm" in Europe/Zurich resolves to 10537200000 ms
// (1970-05-03 00:00 at UTC+1) with minute granularity.
//
// NOTE(review): the test name says "MonthDay" but the behavior asserted here
// is day-then-month; the name is kept unchanged so existing test filters keep
// working.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // Fatal check: a model that failed to load must not be dereferenced below.
  ASSERT_TRUE(classifier);

  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);

  // In Germany, the date should be interpreted as <day>.<month>.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1767 
// "10:30" without an am/pm marker is ambiguous: classification must return
// both interpretations, in order. The expected values 34200000 ms and
// 77400000 ms are 09:30 and 21:30 UTC, i.e. 10:30 and 22:30 at UTC+1.
TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal check: a model that failed to load must not be dereferenced below.
  ASSERT_TRUE(classifier);

  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);

  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1784 
// "10:30" without an am/pm marker is ambiguous: annotation must return a
// single span whose classification list carries both interpretations, in
// order (34200000 ms and 77400000 ms, i.e. 10:30 and 22:30 at UTC+1).
TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal check: a model that failed to load must not be dereferenced below.
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<AnnotatedSpan> spans =
      classifier->Annotate("set an alarm for 10:30", options);

  // Fatal: spans[0] is accessed right after.
  ASSERT_EQ(spans.size(), 1);
  // Bound by reference; `spans` outlives the use, so no copy is needed.
  const std::vector<ClassificationResult>& result = spans[0].classification;
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1803 
// When every datetime pattern is enabled only for annotation and
// classification, SuggestSelection() must leave the clicked span untouched
// while ClassifyText() and Annotate() still recognize the date.
TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
  std::string test_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. Range-for avoids the signed/unsigned
  // comparison an `int` index against `size()` would introduce.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Classification still works on the full date.
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  // Selection is disabled: the click on "january" is returned unchanged.
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            CodepointSpan(0, 7));
  // Annotation still works on the full date.
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
1827 
TEST_F(AnnotatorTest,AnnotatesWithGrammarModel)1828 TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
1829   const std::string test_model = ReadFile(GetTestModelPath());
1830   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1831 
1832   // Add test grammar model.
1833   unpacked_model->grammar_model.reset(new GrammarModelT);
1834   GrammarModelT* grammar_model = unpacked_model->grammar_model.get();
1835   grammar_model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
1836   grammar_model->tokenizer_options->tokenization_type = TokenizationType_ICU;
1837   grammar_model->tokenizer_options->icu_preserve_whitespace_tokens = false;
1838   grammar_model->tokenizer_options->tokenize_on_script_change = true;
1839 
1840   // Add test rules.
1841   grammar_model->rules.reset(new grammar::RulesSetT);
1842   grammar::LocaleShardMap locale_shard_map =
1843       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1844   grammar::Rules rules(locale_shard_map);
1845   rules.Add("<tv_detective>", {"jessica", "fletcher"});
1846   rules.Add("<tv_detective>", {"columbo"});
1847   rules.Add("<tv_detective>", {"magnum"});
1848   rules.Add(
1849       "<famous_person>", {"<tv_detective>"},
1850       /*callback=*/
1851       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1852       /*callback_param=*/0 /* rule classification result */);
1853 
1854   // Set result.
1855   grammar_model->rule_classification_result.emplace_back(
1856       new GrammarModel_::RuleClassificationResultT);
1857   GrammarModel_::RuleClassificationResultT* result =
1858       grammar_model->rule_classification_result.back().get();
1859   result->collection_name = "famous person";
1860   result->enabled_modes = ModeFlag_ALL;
1861   rules.Finalize().Serialize(/*include_debug_information=*/false,
1862                              grammar_model->rules.get());
1863   flatbuffers::FlatBufferBuilder builder;
1864   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
1865 
1866   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1867       reinterpret_cast<const char*>(builder.GetBufferPointer()),
1868       builder.GetSize(), unilib_.get(), calendarlib_.get());
1869   ASSERT_TRUE(classifier);
1870 
1871   const std::string test_string =
1872       "Did you see the Novel Connection episode where Jessica Fletcher helps "
1873       "Magnum solve the case? I thought that was with Columbo ...";
1874 
1875   EXPECT_THAT(classifier->Annotate(test_string),
1876               ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
1877                           IsAnnotatedSpan(70, 76, "famous person"),
1878                           IsAnnotatedSpan(117, 124, "famous person")));
1879   EXPECT_THAT(FirstResult(classifier->ClassifyText("Jessica Fletcher",
1880                                                    CodepointSpan{0, 16})),
1881               Eq("famous person"));
1882   EXPECT_THAT(classifier->SuggestSelection("Jessica Fletcher", {0, 7}),
1883               Eq(CodepointSpan{0, 16}));
1884 }
1885 
// A single candidate has nothing to conflict with and must be chosen.
TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
  };
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  // Receives the indices (into `spans`) of the winning candidates.
  std::vector<int> winners;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             locales, smart_options,
                             /*interpreter_manager=*/nullptr, &winners);
  EXPECT_THAT(winners, ElementsAreArray({0}));
}
1901 
// Adjacent, non-overlapping candidates do not conflict: all are chosen.
TEST_F(AnnotatorTest, ResolveConflictsSequence) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  // Five back-to-back one-codepoint spans.
  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
  };
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  // Receives the indices (into `spans`) of the winning candidates.
  std::vector<int> winners;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             locales, smart_options,
                             /*interpreter_manager=*/nullptr, &winners);
  EXPECT_THAT(winners, ElementsAreArray({0, 1, 2, 3, 4}));
}
1922 
// Of three candidates where the middle one overlaps both neighbours, the
// middle one loses and the two outer, non-overlapping spans are chosen.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  // Receives the indices (into `spans`) of the winning candidates.
  std::vector<int> winners;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             locales, smart_options,
                             /*interpreter_manager=*/nullptr, &winners);
  EXPECT_THAT(winners, ElementsAreArray({0, 2}));
}
1941 
// Mirror of ResolveConflictsThreeSpans: here the middle candidate overlaps
// both neighbours and wins, so both outer spans are dropped.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
  };
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  // Receives the indices (into `spans`) of the winning candidates.
  std::vector<int> winners;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             locales, smart_options,
                             /*interpreter_manager=*/nullptr, &winners);
  EXPECT_THAT(winners, ElementsAreArray({1}));
}
1960 
// With a default-constructed TestingAnnotator (no model options that favor
// longer annotations), the long "url" span does not beat the shorter spans it
// overlaps.
TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({3, 7}, "unit", 1),
      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser!
      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser!
      MakeAnnotatedSpan({14, 20}, "email", 1),
  };
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  // Receives the indices (into `spans`) of the winning candidates.
  std::vector<int> winners;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             locales, smart_options,
                             /*interpreter_manager=*/nullptr, &winners);
  // Picks the first and the last annotations because they do not overlap.
  EXPECT_THAT(winners, ElementsAreArray({0, 3}));
}
1981 
// Same four candidates as the non-prioritizing case, but resolved with a model
// whose conflict_resolution_options->prioritize_longest_annotation is true;
// only the longest candidate is expected to be chosen.
TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
  // Unpack the test model and enable longest-annotation prioritization.
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));
  model->conflict_resolution_options.reset(
      new Model_::ConflictResolutionOptionsT);
  model->conflict_resolution_options->prioritize_longest_annotation = true;

  // Re-pack the modified model and load an annotator from the packed bytes.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<TestingAnnotator> annotator =
      TestingAnnotator::FromUnownedBuffer(packed_model, fbb.GetSize(),
                                          unilib_.get(), calendarlib_.get());
  TC3_CHECK(annotator != nullptr);

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser.
      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser.
      MakeAnnotatedSpan({5, 30}, "url", 1),     // Longest match: the winner.
      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser.
  }};

  std::vector<int> kept_indices;
  annotator->ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                              text_locales, base_options,
                              /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(2));
}
2016 
// Five candidates in SMART mode; the test expects candidates 0, 2 and 4 to be
// chosen.
TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser.
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser.
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  }};

  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0, 2, 4));
}
2037 
// RAW mode: a KNOWLEDGE-sourced candidate listed first overlaps a candidate
// with the default source; both are expected to be kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "address", 0.6),
  }};

  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0, 1));
}
2056 
// RAW mode: a candidate with the default source overlaps a KNOWLEDGE-sourced
// candidate listed second; both are expected to be kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  }};

  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0, 1));
}
2075 
// RAW mode: two overlapping candidates that are both KNOWLEDGE-sourced; both
// are expected to be kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  }};

  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0, 1));
}
2095 
// RAW mode with a default TestingAnnotator: two overlapping candidates, neither
// KNOWLEDGE-sourced; only the first one is expected to be kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  }};

  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0));
}
2113 
// RAW mode with do_conflict_resolution_in_raw_mode switched off in the model:
// the two overlapping candidates are both expected to be kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> text_locales = {Locale::FromBCP47("en")};

  std::vector<AnnotatedSpan> spans{{
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  }};

  // The callback installs fresh conflict resolution options on the model and
  // disables conflict resolution for RAW mode.
  TestingAnnotator annotator(
      unilib_.get(), calendarlib_.get(), [](ModelT* m) {
        m->conflict_resolution_options.reset(
            new Model_::ConflictResolutionOptionsT);
        m->conflict_resolution_options->do_conflict_resolution_in_raw_mode =
            false;
      });

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             text_locales, base_options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAre(0, 1));
}
2137 
VerifyLongInput(const Annotator * classifier)2138 void VerifyLongInput(const Annotator* classifier) {
2139   ASSERT_TRUE(classifier);
2140 
2141   for (const auto& type_value_pair :
2142        std::vector<std::pair<std::string, std::string>>{
2143            {"address", "350 Third Street, Cambridge"},
2144            {"phone", "123 456-7890"},
2145            {"url", "www.google.com"},
2146            {"email", "someone@gmail.com"},
2147            {"flight", "LX 38"},
2148            {"date", "September 1, 2018"}}) {
2149     const std::string input_100k = std::string(50000, ' ') +
2150                                    type_value_pair.second +
2151                                    std::string(50000, ' ');
2152     const int value_length = type_value_pair.second.size();
2153 
2154     AnnotationOptions annotation_options;
2155     annotation_options.locales = "en";
2156     EXPECT_THAT(classifier->Annotate(input_100k, annotation_options),
2157                 ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
2158                                                   type_value_pair.first)}));
2159     SelectionOptions selection_options;
2160     selection_options.locales = "en";
2161     EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001},
2162                                            selection_options),
2163               CodepointSpan(50000, 50000 + value_length));
2164 
2165     ClassificationOptions classification_options;
2166     classification_options.locales = "en";
2167     EXPECT_EQ(type_value_pair.first,
2168               FirstResult(classifier->ClassifyText(
2169                   input_100k, {50000, 50000 + value_length},
2170                   classification_options)));
2171   }
2172 }
2173 
// Runs the long-input checks against the test model loaded from its path.
TEST_F(AnnotatorTest, LongInput) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2179 
// Runs the long-input checks against the model returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, LongInputWithRegExDatetime) {
  std::string regex_datetime_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_datetime_model.data(), regex_datetime_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2187 
2188 // These coarse tests are there only to make sure the execution happens in
2189 // reasonable amount of time.
TEST_F(AnnotatorTest, LongInputNoResultCheck) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::vector<std::string> samples = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
  const std::string padding(50000, ' ');
  for (const std::string& sample : samples) {
    const std::string padded_input = padding + sample + padding;
    const int sample_length = sample.size();

    // The results are intentionally discarded; only completion matters here.
    annotator->Annotate(padded_input);
    annotator->SuggestSelection(padded_input, {50000, 50001});
    annotator->ClassifyText(padded_input, {50000, 50000 + sample_length});
  }
}
2206 
// Checks the effect of classification_options->max_num_tokens on ClassifyText:
// with -1 the sample address is classified as "address"; with 3 the same
// selection is classified as "other".
TEST_F(AnnotatorTest, MaxTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Both pointers are dereferenced below; fail cleanly instead of crashing.
  ASSERT_TRUE(unpacked_model != nullptr);
  ASSERT_TRUE(unpacked_model->classification_options != nullptr);

  // The builders are declared before the annotator so that they are destroyed
  // after it: the annotator is created from a pointer into the builder's
  // buffer (FromUnownedBuffer presumably does not copy it).
  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::FlatBufferBuilder builder2;
  std::unique_ptr<Annotator> classifier;

  // With unrestricted number of tokens should behave normally.
  unpacked_model->classification_options->max_num_tokens = -1;

  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Lower the maximum number of tokens to suppress the classification.
  unpacked_model->classification_options->max_num_tokens = 3;

  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
      builder2.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
2241 
// Checks the effect of classification_options->address_min_num_tokens on
// ClassifyText: with 0 the sample address is classified as "address"; with 5
// the same selection is classified as "other".
TEST_F(AnnotatorTest, MinAddressTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Both pointers are dereferenced below; fail cleanly instead of crashing.
  ASSERT_TRUE(unpacked_model != nullptr);
  ASSERT_TRUE(unpacked_model->classification_options != nullptr);

  // The builders are declared before the annotator so that they are destroyed
  // after it: the annotator is created from a pointer into the builder's
  // buffer (FromUnownedBuffer presumably does not copy it).
  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::FlatBufferBuilder builder2;
  std::unique_ptr<Annotator> classifier;

  // With unrestricted number of address tokens should behave normally.
  unpacked_model->classification_options->address_min_num_tokens = 0;

  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Raise number of address tokens to suppress the address classification.
  unpacked_model->classification_options->address_min_num_tokens = 5;

  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
      builder2.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
2276 
// With triggering_options->other_collection_priority_score set to 1.0, the
// text "LX37" is expected to be classified as "other".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = 1.0;

  // Re-pack the modified model and load an annotator from the packed bytes.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "other");
}
2291 
// With triggering_options->other_collection_priority_score set to -100.0, the
// text "LX37" is expected to be classified as "flight".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = -100.0;

  // Re-pack the modified model and load an annotator from the packed bytes.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "flight");
}
2306 
// VisitAnnotatorModel<bool> with a visitor that reports whether it was handed
// a non-null model: expected true for the test model path and false for a
// path named "non_existing_model.fb".
TEST_F(AnnotatorTest, VisitAnnotatorModel) {
  const std::string existing_path = GetTestModelPath();
  EXPECT_TRUE(VisitAnnotatorModel<bool>(
      existing_path, [](const Model* model) { return model != nullptr; }));

  const std::string missing_path = GetModelPath() + "non_existing_model.fb";
  EXPECT_FALSE(VisitAnnotatorModel<bool>(
      missing_path, [](const Model* model) { return model != nullptr; }));
}
2323 
// Model triggering locales are set to "en,cs" and no detected language tags
// are passed: Annotate, ClassifyText and SuggestSelection are all expected to
// recognize the phone number.
TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone_text),
              ElementsAre(IsAnnotatedSpan(0, 14, "phone")));
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(phone_text, {0, 14})));
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {6, 9}),
            CodepointSpan(0, 14));
}
2340 
// Model triggering locales are "en,cs" and the detected language tag is "cs":
// Annotate is expected to return the phone span.
TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone_text, annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 14, "phone")));
}
2355 
// Model triggering locales are "en,cs" and the detected language tag is "de":
// Annotate is expected to return nothing.
TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone_text, annotation_options), IsEmpty());
}
2369 
// Model triggering locales are "en,cs" and the detected language tag is "cs":
// ClassifyText is expected to report "phone" first.
TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(
                         phone_text, {0, 14}, classification_options)));
}
2384 
// Model triggering locales are "en,cs" and the detected language tag is "de":
// ClassifyText is expected to return no results.
TEST_F(AnnotatorTest,
       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(
      annotator->ClassifyText(phone_text, {0, 14}, classification_options),
      IsEmpty());
}
2400 
// Model triggering locales are "en,cs" and the detected language tag is "cs":
// SuggestSelection is expected to expand {6, 9} to the whole phone number.
TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_EQ(
      annotator->SuggestSelection(phone_text, {6, 9}, selection_options),
      CodepointSpan(0, 14));
}
2415 
// Model triggering locales are "en,cs" and the detected language tag is "de":
// SuggestSelection is expected to hand back the input selection {6, 9}.
TEST_F(AnnotatorTest,
       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* m) { m->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";

  const std::string phone_text = "(555) 225-3556";
  EXPECT_EQ(
      annotator->SuggestSelection(phone_text, {6, 9}, selection_options),
      CodepointSpan(6, 9));
}
2431 
// Both triggering_locales and triggering_options->locales are set to "en,cs"
// and no detected language tags are passed: Annotate, ClassifyText and
// SuggestSelection are all expected to recognize the address.
TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address_text),
              ElementsAre(IsAnnotatedSpan(0, 27, "address")));
  EXPECT_EQ("address",
            FirstResult(annotator->ClassifyText(address_text, {0, 27})));
  EXPECT_EQ(annotator->SuggestSelection(address_text, {4, 9}),
            CodepointSpan(0, 27));
}
2450 
// Both locale settings are "en,cs" and the detected language tag is "cs":
// Annotate is expected to return the address span.
TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address_text, annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 27, "address")));
}
2467 
// Both locale settings are "en,cs" and the detected language tag is "de":
// Annotate is expected to return nothing.
TEST_F(AnnotatorTest,
       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address_text, annotation_options),
              IsEmpty());
}
2485 
// Both locale settings are "en,cs" and the detected language tag is "cs":
// ClassifyText is expected to report "address" first.
TEST_F(AnnotatorTest,
       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_EQ("address", FirstResult(annotator->ClassifyText(
                           address_text, {0, 27}, classification_options)));
}
2503 
// Both locale settings are "en,cs" and the detected language tag is "de":
// ClassifyText is expected to return no results.
TEST_F(AnnotatorTest,
       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(
      annotator->ClassifyText(address_text, {0, 27}, classification_options),
      IsEmpty());
}
2522 
// Both locale settings are "en,cs" and the detected language tag is "cs":
// SuggestSelection is expected to expand {4, 9} to the whole address.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_EQ(
      annotator->SuggestSelection(address_text, {4, 9}, selection_options),
      CodepointSpan(0, 27));
}
2541 
// Both locale settings are "en,cs" and the detected language tag is "de":
// SuggestSelection is expected to hand back the input selection {4, 9}.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* m) {
        m->triggering_locales = "en,cs";
        m->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_EQ(
      annotator->SuggestSelection(address_text, {4, 9}, selection_options),
      CodepointSpan(4, 9));
}
2560 
VerifyClassifyTextOutputsDatetimeEntityData(const Annotator * classifier)2561 void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
2562   EXPECT_TRUE(classifier);
2563   std::vector<ClassificationResult> result;
2564   ClassificationOptions options;
2565   options.locales = "en-US";
2566 
2567   result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
2568 
2569   ASSERT_GE(result.size(), 0);
2570   const EntityData* entity_data =
2571       GetEntityData(result[0].serialized_entity_data.data());
2572   ASSERT_NE(entity_data, nullptr);
2573   ASSERT_NE(entity_data->datetime(), nullptr);
2574   EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
2575   EXPECT_EQ(entity_data->datetime()->granularity(),
2576             EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
2577   EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
2578 
2579   auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
2580   EXPECT_EQ(meridiem->component_type(),
2581             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
2582   EXPECT_EQ(meridiem->absolute_value(), 0);
2583   EXPECT_EQ(meridiem->relative_count(), 0);
2584   EXPECT_EQ(meridiem->relation_type(),
2585             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2586 
2587   auto* minute = entity_data->datetime()->datetime_component()->Get(1);
2588   EXPECT_EQ(minute->component_type(),
2589             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
2590   EXPECT_EQ(minute->absolute_value(), 0);
2591   EXPECT_EQ(minute->relative_count(), 0);
2592   EXPECT_EQ(minute->relation_type(),
2593             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2594 
2595   auto* hour = entity_data->datetime()->datetime_component()->Get(2);
2596   EXPECT_EQ(hour->component_type(),
2597             EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
2598   EXPECT_EQ(hour->absolute_value(), 0);
2599   EXPECT_EQ(hour->relative_count(), 0);
2600   EXPECT_EQ(hour->relation_type(),
2601             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2602 
2603   auto* day = entity_data->datetime()->datetime_component()->Get(3);
2604   EXPECT_EQ(
2605       day->component_type(),
2606       EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2607   EXPECT_EQ(day->absolute_value(), 5);
2608   EXPECT_EQ(day->relative_count(), 0);
2609   EXPECT_EQ(day->relation_type(),
2610             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2611 
2612   auto* month = entity_data->datetime()->datetime_component()->Get(4);
2613   EXPECT_EQ(month->component_type(),
2614             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2615   EXPECT_EQ(month->absolute_value(), 3);
2616   EXPECT_EQ(month->relative_count(), 0);
2617   EXPECT_EQ(month->relation_type(),
2618             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2619 
2620   auto* year = entity_data->datetime()->datetime_component()->Get(5);
2621   EXPECT_EQ(year->component_type(),
2622             EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2623   EXPECT_EQ(year->absolute_value(), 1970);
2624   EXPECT_EQ(year->relative_count(), 0);
2625   EXPECT_EQ(year->relation_type(),
2626             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2627 }
2628 
// Runs the ClassifyText datetime entity-data checks against the model found at
// GetTestModelPath().
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2634 
// Same checks as ClassifyTextOutputsDatetimeEntityData, but on the model
// variant returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) {
  // The annotator is created via FromUnownedBuffer, so the serialized model is
  // kept in a local that lives for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2642 
VerifyAnnotateOutputsDatetimeEntityData(const Annotator * classifier)2643 void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
2644   EXPECT_TRUE(classifier);
2645   std::vector<AnnotatedSpan> result;
2646   AnnotationOptions options;
2647   options.is_serialized_entity_data_enabled = true;
2648   options.locales = "en";
2649 
2650   result = classifier->Annotate("September 1, 2019", options);
2651 
2652   ASSERT_GE(result.size(), 0);
2653   ASSERT_GE(result[0].classification.size(), 0);
2654   ASSERT_EQ(result[0].classification[0].collection, "date");
2655   const EntityData* entity_data =
2656       GetEntityData(result[0].classification[0].serialized_entity_data.data());
2657   ASSERT_NE(entity_data, nullptr);
2658   ASSERT_NE(entity_data->datetime(), nullptr);
2659   EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
2660   EXPECT_EQ(entity_data->datetime()->granularity(),
2661             EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
2662   EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
2663 
2664   auto* day = entity_data->datetime()->datetime_component()->Get(0);
2665   EXPECT_EQ(
2666       day->component_type(),
2667       EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2668   EXPECT_EQ(day->absolute_value(), 1);
2669   EXPECT_EQ(day->relative_count(), 0);
2670   EXPECT_EQ(day->relation_type(),
2671             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2672 
2673   auto* month = entity_data->datetime()->datetime_component()->Get(1);
2674   EXPECT_EQ(month->component_type(),
2675             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2676   EXPECT_EQ(month->absolute_value(), 9);
2677   EXPECT_EQ(month->relative_count(), 0);
2678   EXPECT_EQ(month->relation_type(),
2679             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2680 
2681   auto* year = entity_data->datetime()->datetime_component()->Get(2);
2682   EXPECT_EQ(year->component_type(),
2683             EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2684   EXPECT_EQ(year->absolute_value(), 2019);
2685   EXPECT_EQ(year->relative_count(), 0);
2686   EXPECT_EQ(year->relation_type(),
2687             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2688 }
2689 
// Runs the Annotate datetime entity-data checks against the model found at
// GetTestModelPath().
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2695 
// Same checks as AnnotateOutputsDatetimeEntityData, but on the model variant
// returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) {
  // The annotator is created via FromUnownedBuffer, so the serialized model is
  // kept in a local that lives for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2703 
TEST_F(AnnotatorTest,AnnotateOutputsMoneyEntityData)2704 TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
2705   // std::string model_buffer = GetTestModelWithDatetimeRegEx();
2706   // std::unique_ptr<Annotator> classifier =
2707   //     Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
2708   //                                  unilib_.get(), calendarlib_.get());
2709   std::unique_ptr<Annotator> classifier = Annotator::FromPath(
2710       GetTestModelPath(), unilib_.get(), calendarlib_.get());
2711   EXPECT_TRUE(classifier);
2712   AnnotationOptions options;
2713   options.is_serialized_entity_data_enabled = true;
2714 
2715   ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
2716                            /*amount=*/"3.5", /*whole_part=*/3,
2717                            /*decimal_part=*/5, /*nanos=*/500000000);
2718   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
2719                            /*amount=*/"3.5", /*whole_part=*/3,
2720                            /*decimal_part=*/5, /*nanos=*/500000000);
2721   ExpectFirstEntityIsMoney(
2722       classifier->Annotate("For online purchase of CHF 23.00 enter", options),
2723       "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
2724       /*nanos=*/0);
2725   ExpectFirstEntityIsMoney(
2726       classifier->Annotate("For online purchase of 23.00 CHF enter", options),
2727       "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
2728       /*nanos=*/0);
2729   ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
2730                            /*amount=*/"4.8198", /*whole_part=*/4,
2731                            /*decimal_part=*/8198, /*nanos=*/819800000);
2732   ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
2733                            /*amount=*/"4.8198", /*whole_part=*/4,
2734                            /*decimal_part=*/8198, /*nanos=*/819800000);
2735   ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
2736                            /*amount=*/"0.0255", /*whole_part=*/0,
2737                            /*decimal_part=*/255, /*nanos=*/25500000);
2738   ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
2739                            /*amount=*/"0.0255", /*whole_part=*/0,
2740                            /*decimal_part=*/255, /*nanos=*/25500000);
2741   ExpectFirstEntityIsMoney(
2742       classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
2743                            "OR on card ending 0000.",
2744                            options),
2745       "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
2746       /*nanos=*/0);
2747   ExpectFirstEntityIsMoney(
2748       classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
2749                            "OR on card ending 0000.",
2750                            options),
2751       "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
2752       /*nanos=*/0);
2753 
2754   ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
2755                            /*amount=*/"35",
2756                            /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
2757   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
2758                            /*amount=*/"35", /*whole_part=*/35,
2759                            /*decimal_part=*/0, /*nanos=*/0);
2760   ExpectFirstEntityIsMoney(
2761       classifier->Annotate("and win back up to CHF 150 - with digitec",
2762                            options),
2763       "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
2764       /*nanos=*/0);
2765   ExpectFirstEntityIsMoney(
2766       classifier->Annotate("and win back up to 150 CHF - with digitec",
2767                            options),
2768       "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
2769       /*nanos=*/0);
2770 
2771   ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
2772                            "CHF", /*amount=*/"3.555.333",
2773                            /*whole_part=*/3555333, /*decimal_part=*/0,
2774                            /*nanos=*/0);
2775   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
2776                            "CHF", /*amount=*/"3.555.333",
2777                            /*whole_part=*/3555333, /*decimal_part=*/0,
2778                            /*nanos=*/0);
2779   ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
2780                            /*amount=*/"10,000", /*whole_part=*/10000,
2781                            /*decimal_part=*/0, /*nanos=*/0);
2782   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
2783                            /*amount=*/"10,000", /*whole_part=*/10000,
2784                            /*decimal_part=*/0, /*nanos=*/0);
2785 
2786   ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
2787                            /*amount=*/"3,555.33", /*whole_part=*/3555,
2788                            /*decimal_part=*/33, /*nanos=*/330000000);
2789   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
2790                            /*amount=*/"3,555.33", /*whole_part=*/3555,
2791                            /*decimal_part=*/33, /*nanos=*/330000000);
2792   ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
2793                            /*amount=*/"3,000.00", /*whole_part=*/3000,
2794                            /*decimal_part=*/0, /*nanos=*/0);
2795   ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
2796                            /*amount=*/"3,000.00", /*whole_part=*/3000,
2797                            /*decimal_part=*/0, /*nanos=*/0);
2798 
2799   ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
2800                            /*amount=*/"1.2", /*whole_part=*/1,
2801                            /*decimal_part=*/2, /*nanos=*/200000000);
2802   ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
2803                            /*amount=*/"1.2", /*whole_part=*/1,
2804                            /*decimal_part=*/2, /*nanos=*/200000000);
2805 
2806   ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
2807                            /*amount=*/"1.123456789", /*whole_part=*/1,
2808                            /*decimal_part=*/123456789, /*nanos=*/123456789);
2809   ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
2810                            /*amount=*/"10.01", /*whole_part=*/10,
2811                            /*decimal_part=*/1, /*nanos=*/10000000);
2812 
2813   ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
2814                            /*amount=*/"59 million", /*whole_part=*/59000000,
2815                            /*decimal_part=*/0, /*nanos=*/0);
2816   ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
2817                            /*amount=*/"7.05 k", /*whole_part=*/7050,
2818                            /*decimal_part=*/5, /*nanos=*/0);
2819   ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
2820                            /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
2821                            /*decimal_part=*/123456789, /*nanos=*/789000000);
2822   ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
2823                            /*amount=*/"7.000056789 k", /*whole_part=*/7000,
2824                            /*decimal_part=*/56789, /*nanos=*/56789000);
2825 
2826   ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
2827                            /*amount=*/"59.3 billion", /*whole_part=*/59,
2828                            /*decimal_part=*/3, /*nanos=*/300000000);
2829   ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
2830                            /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
2831                            /*decimal_part=*/5, /*nanos=*/0);
2832 }
2833 
// With a LangId model attached and "de" as the only user-familiar language,
// classifying a span of an English sentence is expected to yield exactly one
// result, in the "translate" collection.
TEST_F(AnnotatorTest, TranslateAction) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal check: the classifier is dereferenced right below.
  ASSERT_TRUE(classifier);
  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
          GetModelPath() + "lang_id.smfb");
  classifier->SetLangId(langid_model.get());

  ClassificationOptions options;
  options.user_familiar_language_tags = "de";

  std::vector<ClassificationResult> classifications =
      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
  // Fatal size check: classifications[0] is read on the next line.
  ASSERT_EQ(classifications.size(), 1);
  EXPECT_EQ(classifications[0].collection, "translate");
}
2850 
// Annotates two text fragments in one AnnotateStructuredInput call and expects
// one "money" span in the first fragment and one "date" span in the second.
TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal check: the classifier is dereferenced below.
  ASSERT_TRUE(classifier);

  std::vector<InputFragment> string_fragments = {
      {.text = "He owes me 3.5 CHF."},
      {.text = "...was born on 13/12/1989."},
  };

  AnnotationOptions annotation_options;
  annotation_options.locales = "en";
  StatusOr<Annotations> annotations_status =
      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
  ASSERT_TRUE(annotations_status.ok());
  Annotations annotations = annotations_status.ValueOrDie();
  // One list of annotated spans is expected per input fragment.
  ASSERT_EQ(annotations.annotated_spans.size(), 2);
  EXPECT_THAT(annotations.annotated_spans[0],
              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
  EXPECT_THAT(annotations.annotated_spans[1],
              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
}
2872 
VerifyInputFragmentTimestampOverridesAnnotationOptions(const Annotator * classifier)2873 void VerifyInputFragmentTimestampOverridesAnnotationOptions(
2874     const Annotator* classifier) {
2875   AnnotationOptions annotation_options;
2876   annotation_options.locales = "en";
2877   annotation_options.reference_time_ms_utc =
2878       1554465190000;                             // 04/05/2019 11:53 am
2879   int64 fragment_reference_time = 946727580000;  // 01/01/2000 11:53 am
2880   std::vector<InputFragment> string_fragments = {
2881       {.text = "New event at 17:20"},
2882       {
2883           .text = "New event at 17:20",
2884           .datetime_options = Optional<DatetimeOptions>(
2885               {.reference_time_ms_utc = fragment_reference_time}),
2886       }};
2887   StatusOr<Annotations> annotations_status =
2888       classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2889   ASSERT_TRUE(annotations_status.ok());
2890   Annotations annotations = annotations_status.ValueOrDie();
2891   ASSERT_EQ(annotations.annotated_spans.size(), 2);
2892   EXPECT_THAT(annotations.annotated_spans[0],
2893               ElementsAreArray({IsDatetimeSpan(
2894                   /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
2895                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2896   EXPECT_THAT(annotations.annotated_spans[1],
2897               ElementsAreArray({IsDatetimeSpan(
2898                   /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
2899                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2900 }
2901 
// Runs the fragment-timestamp override checks on the model variant returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) {
  // The annotator is created via FromUnownedBuffer, so the serialized model is
  // kept in a local that lives for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyInputFragmentTimestampOverridesAnnotationOptions(annotator.get());
}
2910 
// Runs the fragment-timestamp override checks on the model found at
// GetTestModelPath().
TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyInputFragmentTimestampOverridesAnnotationOptions(annotator.get());
}
2916 
VerifyInputFragmentTimezoneOverridesAnnotationOptions(const Annotator * classifier)2917 void VerifyInputFragmentTimezoneOverridesAnnotationOptions(
2918     const Annotator* classifier) {
2919   std::vector<InputFragment> string_fragments = {
2920       {.text = "11/12/2020 17:20"},
2921       {
2922           .text = "11/12/2020 17:20",
2923           .datetime_options = Optional<DatetimeOptions>(
2924               {.reference_timezone = "Europe/Zurich"}),
2925       }};
2926   AnnotationOptions annotation_options;
2927   annotation_options.locales = "en-US";
2928   StatusOr<Annotations> annotations_status =
2929       classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2930   ASSERT_TRUE(annotations_status.ok());
2931   Annotations annotations = annotations_status.ValueOrDie();
2932   ASSERT_EQ(annotations.annotated_spans.size(), 2);
2933   EXPECT_THAT(annotations.annotated_spans[0],
2934               ElementsAreArray({IsDatetimeSpan(
2935                   /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
2936                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2937   EXPECT_THAT(annotations.annotated_spans[1],
2938               ElementsAreArray({IsDatetimeSpan(
2939                   /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
2940                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2941 }
2942 
// Runs the fragment-timezone override checks on the model found at
// GetTestModelPath().
TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(annotator.get());
}
2948 
// Runs the fragment-timezone override checks on the model variant returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) {
  // The annotator is created via FromUnownedBuffer, so the serialized model is
  // kept in a local that lives for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(annotator.get());
}
2957 
2958 namespace {
AddDummyRegexDatetimeModel(ModelT * unpacked_model)2959 void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
2960   unpacked_model->datetime_model.reset(new DatetimeModelT);
2961   // This needs to be false otherwise we'd have to define some extractor. When
2962   // this is false, the 0-th capturing group (whole match) from the pattern is
2963   // used to come up with the indices.
2964   unpacked_model->datetime_model->use_extractors_for_locating = false;
2965   unpacked_model->datetime_model->locales.push_back("en-US");
2966   unpacked_model->datetime_model->default_locales.push_back(0);  // en-US
2967   unpacked_model->datetime_model->patterns.push_back(
2968       std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
2969   unpacked_model->datetime_model->patterns.back()->locales.push_back(
2970       0);  // en-US
2971   unpacked_model->datetime_model->patterns.back()->regexes.push_back(
2972       std::unique_ptr<DatetimeModelPattern_::RegexT>(
2973           new DatetimeModelPattern_::RegexT));
2974   unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern =
2975       "THIS_MATCHES_IN_REGEX_MODEL";
2976   unpacked_model->datetime_model->patterns.back()
2977       ->regexes.back()
2978       ->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
2979 }
2980 }  // namespace
2981 
// Annotates "1000000000" in RAW mode and expects exactly one span whose top
// classification is "phone", i.e. duplicate annotations are filtered out.
TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // This test assumes that both ML model and Regex model trigger on the
  // following text and output "phone" annotation for it.
  const std::string test_string = "1000000000";
  AnnotationOptions options;
  options.annotation_usecase = ANNOTATION_USECASE_RAW;
  int num_phones = 0;
  for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
    // A span without classifications cannot be a phone; checking for emptiness
    // first avoids indexing past the end of the vector.
    if (!span.classification.empty() &&
        span.classification[0].collection == "phone") {
      num_phones++;
    }
  }

  EXPECT_EQ(num_phones, 1);
}
3001 
3002 // This test tests the optimizations in Annotator, which make some of the
3003 // annotators not run in the RAW mode when not requested. We test here that the
// results indeed don't contain such annotations. However, this is a bit hacky,
3005 // since one could also add post-filtering, in which case these tests would
3006 // trivially pass.
TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions raw_options;
  // Only a type that does not exist is requested, so that it cannot overlap
  // with any of the existing types.
  raw_options.entity_types.insert("some_unknown_entity_type");
  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // Without the entity type restriction, the text below would normally be
  // annotated as follows:
  //   Span(19, 24, date, 1.000000),
  //   Span(53, 56, number, 1.000000),
  //   Span(53, 80, address, 1.000000),
  //   Span(128, 142, phone, 1.000000),
  //   Span(129, 132, number, 1.000000),
  //   Span(192, 200, phone, 1.000000),
  //   Span(192, 206, datetime, 1.000000),
  //   Span(246, 253, number, 1.000000),
  //   Span(246, 253, phone, 1.000000),
  //   Span(292, 293, number, 1.000000),
  //   Span(292, 301, duration, 1.000000) }
  // None of these entities was requested, so with the optimizations in place
  // nothing is expected to be produced.
  EXPECT_THAT(annotator->Annotate(R"--(I saw Barack Obama today
                            350 Third Street, Cambridge
                            my phone number is (853) 225-3556
                            this is when we met: 1.9.2021 13:00
                            my number: 1234567
                            duration: 3 minutes
                            )--",
                                  raw_options),
              IsEmpty());
}
3041 
VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(const Annotator * classifier)3042 void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(
3043     const Annotator* classifier) {
3044   ASSERT_TRUE(classifier);
3045   struct Example {
3046     std::string collection;
3047     std::string text;
3048   };
3049 
3050   // These examples contain one example per annotator, to check that each of
3051   // the annotators can work in the RAW mode on its own.
3052   //
3053   // WARNING: This list doesn't contain yet entries for the app, contact, and
3054   // person annotators. Hopefully this won't be needed once b/155214735 is
3055   // fixed and the piping shared across annotators.
3056   std::vector<Example> examples{
3057       // ML Model.
3058       {.collection = Collections::Address(),
3059        .text = "... 350 Third Street, Cambridge ..."},
3060       // Datetime annotator.
3061       {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
3062       // Duration annotator.
3063       {.collection = Collections::Duration(),
3064        .text = "... 3 hours and 9 seconds ..."},
3065       // Regex annotator.
3066       {.collection = Collections::Email(),
3067        .text = "... platypus@theanimal.org ..."},
3068       // Number annotator.
3069       {.collection = Collections::Number(), .text = "... 100 ..."},
3070   };
3071 
3072   for (const Example& example : examples) {
3073     AnnotationOptions options;
3074     options.locales = "en";
3075     options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
3076     options.entity_types.insert(example.collection);
3077 
3078     EXPECT_THAT(classifier->Annotate(example.text, options),
3079                 Contains(IsAnnotationWithType(example.collection)))
3080         << " text: '" << example.text
3081         << "', collection: " << example.collection;
3082   }
3083 }
3084 
// Runs the per-collection RAW mode checks on the model found at
// GetTestModelPath().
TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3090 
// Runs the per-collection RAW mode checks on the model variant returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) {
  // The annotator is created via FromUnownedBuffer, so the serialized model is
  // kept in a local that lives for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3099 
// An annotator built with FromString from the contents of the test model file
// must be usable: annotating a phone number yields at least one annotation.
TEST_F(AnnotatorTest, InitializeFromString) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  const std::unique_ptr<Annotator> annotator = Annotator::FromString(
      serialized_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate("(857) 225-3556"), Not(IsEmpty()));
}
3107 
3108 // Regression test for cl/338280366. Enabling only_use_line_with_click had
3109 // the effect, that some annotators in the previous code releases would
3110 // receive only the last line of the input text. This test has the entity on the
3111 // first line (duration).
TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Enable the option under test before re-packing the model.
  unpacked_model->selection_feature_options->only_use_line_with_click = true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  // The annotator is built with FromUnownedBuffer over the builder's buffer,
  // so it is declared after `builder` and therefore destroyed before the
  // buffer it was created from.
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // The duration is on the first line of a two-line text and must be found.
  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("let's meet in 3 hours\nbut not now", options);

  EXPECT_THAT(annotations, Contains(IsDurationSpan(
                               /*start=*/14, /*end=*/21,
                               /*duration_ms=*/3 * 60 * 60 * 1000)));
}
3138 
// Text that is not valid UTF-8 must not be processed: Annotate and
// ClassifyText return nothing and SuggestSelection returns the input span
// unchanged, even though the text starts with a phone number.
TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  const std::unique_ptr<Annotator> annotator = Annotator::FromString(
      serialized_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // A phone number followed by four copies of the byte sequence ED A0 80,
  // which encodes the surrogate code point U+D800 and is invalid in UTF-8.
  const std::string malformed_text =
      "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";

  EXPECT_THAT(annotator->Annotate(malformed_text), IsEmpty());
  EXPECT_THAT(annotator->SuggestSelection(malformed_text, {1, 4}),
              Eq(CodepointSpan{1, 4}));
  EXPECT_THAT(annotator->ClassifyText(malformed_text, {0, 14}), IsEmpty());
}
3156 
3157 }  // namespace test_internal
3158 }  // namespace libtextclassifier3
3159