1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/grammar/grammar-annotator.h"
18 
19 #include <memory>
20 
21 #include "annotator/grammar/test-utils.h"
22 #include "annotator/grammar/utils.h"
23 #include "annotator/model_generated.h"
24 #include "utils/flatbuffers/flatbuffers.h"
25 #include "utils/flatbuffers/mutable.h"
26 #include "utils/grammar/utils/locale-shard-map.h"
27 #include "utils/grammar/utils/rules.h"
28 #include "utils/tokenizer.h"
29 #include "utils/utf8/unicodetext.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
36 using testing::ElementsAre;
37 
PackModel(const GrammarModelT & model)38 flatbuffers::DetachedBuffer PackModel(const GrammarModelT& model) {
39   flatbuffers::FlatBufferBuilder builder;
40   builder.Finish(GrammarModel::Pack(builder, &model));
41   return builder.Release();
42 }
43 
// Annotate() should report one "flight" span for every place in the text
// where the <flight> root rule matches.
TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) {
  // Set up the model skeleton the rules get serialized into.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);

  // A flight is a carrier ("lx" or "aa") followed by a <2_digits>,
  // <3_digits> or <4_digits> flight code.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Register the classification result, then attach it to the root rule.
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<flight>", {"<carrier>", "<flight_code>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/flight_result_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer model_buffer = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(model_buffer));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));

  // Both "LX 38" and "AA2014" are expected, in text order.
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
                                 IsAnnotatedSpan(51, 57, "flight")));
}
79 
// Annotate() should honour a negative assertion on the right context: a
// flight candidate followed by an optional "." and digits is not reported.
TEST_F(GrammarAnnotatorTest, HandlesAssertions) {
  // Set up the model skeleton the rules get serialized into.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Flight: carrier + flight code, with an optional check of what follows.
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<flight>",
            {"<carrier>", "<flight_code>", "<context_assertion>?"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/flight_result_id);

  // The assertion is negative, so inputs such as "LX 38.00" are excluded.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer model_buffer = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(model_buffer));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text =
      UTF8ToUnicodeText("My flight: LX 38 arriving at 4pm, I'll fly back on "
                        "AA2014 on LX 38.00",
                        /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));

  // The trailing "LX 38.00" must not show up among the annotations.
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
                                 IsAnnotatedSpan(51, 57, "flight")));
}
122 
// With a capturing group marked `extend_selection`, the annotated span
// should cover only the captured digits, not the "please call" trigger.
TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {
  // Set up the model skeleton the rules get serialized into.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);

  // Value mapping 0 ties <low_confidence_phone> to the capturing group that
  // is appended to the classification result below.
  rules.AddValueMapping("<low_confidence_phone>", {"<digits>"},
                        /*value=*/0);

  // Register the "phone" result and give it one selection-extending group.
  const int phone_result_id =
      AddRuleClassificationResult("phone", ModeFlag_ALL, 1.0, &grammar_model);
  grammar_model.rule_classification_result[phone_result_id]
      ->capturing_group.emplace_back(new CapturingGroupT);
  CapturingGroupT* phone_group =
      grammar_model.rule_classification_result[phone_result_id]
          ->capturing_group.back()
          .get();
  phone_group->extend_selection = true;

  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<phone>", {"please", "call", "<low_confidence_phone>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/phone_result_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer model_buffer = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(model_buffer));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text =
      UTF8ToUnicodeText("Please call 911 before 10 am!", /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));

  // Only "911" (codepoints 12..15) is expected to be selected.
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(12, 15, "phone")));
}
161 
// ClassifyText() should classify a span that exactly covers a <flight>
// match as "flight".
TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) {
  // Set up the model skeleton the rules get serialized into.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Register the classification result, then attach it to the root rule.
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<flight>", {"<carrier>", "<flight_code>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/flight_result_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer model_buffer = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(model_buffer));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);

  // Codepoints 11..16 cover "LX 38".
  const CodepointSpan flight_span{11, 16};
  ClassificationResult classification;
  EXPECT_TRUE(
      annotator.ClassifyText(locales, text, flight_span, &classification));
  EXPECT_THAT(classification, IsClassificationResult("flight"));
}
195 
TEST_F(GrammarAnnotatorTest,ClassifiesTextWithAssertions)196 TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) {
197   // Create test rules.
198   GrammarModelT grammar_model;
199   SetTestTokenizerOptions(&grammar_model);
200   grammar_model.rules.reset(new grammar::RulesSetT);
201 
202   // Use unbounded context.
203   grammar_model.context_left_num_tokens = -1;
204   grammar_model.context_right_num_tokens = -1;
205 
206   grammar::LocaleShardMap locale_shard_map =
207       grammar::LocaleShardMap::CreateLocaleShardMap({""});
208   grammar::Rules rules(locale_shard_map);
209   rules.Add("<carrier>", {"lx"});
210   rules.Add("<carrier>", {"aa"});
211   rules.Add("<flight_code>", {"<2_digits>"});
212   rules.Add("<flight_code>", {"<3_digits>"});
213   rules.Add("<flight_code>", {"<4_digits>"});
214   rules.AddValueMapping("<flight_selection>", {"<carrier>", "<flight_code>"},
215                         /*value=*/0);
216 
217   // Flight: carrier + flight code and check right context.
218   const int classification_result_id =
219       AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
220   rules.Add(
221       "<flight>", {"<flight_selection>", "<context_assertion>?"},
222       /*callback=*/
223       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
224       /*callback_param=*/
225       classification_result_id);
226 
227   grammar_model.rule_classification_result[classification_result_id]
228       ->capturing_group.emplace_back(new CapturingGroupT);
229   grammar_model.rule_classification_result[classification_result_id]
230       ->capturing_group.back()
231       ->extend_selection = true;
232 
233   // Exclude matches like: LX 38.00 etc.
234   rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
235                      /*negative=*/true);
236 
237   rules.Finalize().Serialize(/*include_debug_information=*/false,
238                              grammar_model.rules.get());
239   flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
240   GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
241 
242   EXPECT_FALSE(annotator.ClassifyText(
243       {Locale::FromBCP47("en")},
244       UTF8ToUnicodeText("See LX 38.00", /*do_copy=*/false), CodepointSpan{4, 9},
245       nullptr));
246   EXPECT_FALSE(annotator.ClassifyText(
247       {Locale::FromBCP47("en")},
248       UTF8ToUnicodeText("See LX 38 00", /*do_copy=*/false), CodepointSpan{4, 9},
249       nullptr));
250   ClassificationResult result;
251   EXPECT_TRUE(annotator.ClassifyText(
252       {Locale::FromBCP47("en")},
253       UTF8ToUnicodeText("See LX 38, seat 5", /*do_copy=*/false),
254       CodepointSpan{4, 9}, &result));
255   EXPECT_THAT(result, IsClassificationResult("flight"));
256 }
257 
TEST_F(GrammarAnnotatorTest,ClassifiesTextWithContext)258 TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) {
259   // Create test rules.
260   GrammarModelT grammar_model;
261   SetTestTokenizerOptions(&grammar_model);
262   grammar_model.rules.reset(new grammar::RulesSetT);
263 
264   // Max three tokens to the left ("tracking number: ...").
265   grammar_model.context_left_num_tokens = 3;
266   grammar_model.context_right_num_tokens = 0;
267 
268   grammar::LocaleShardMap locale_shard_map =
269       grammar::LocaleShardMap::CreateLocaleShardMap({""});
270   grammar::Rules rules(locale_shard_map);
271   rules.Add("<tracking_number>", {"<5_digits>"});
272   rules.Add("<tracking_number>", {"<6_digits>"});
273   rules.Add("<tracking_number>", {"<7_digits>"});
274   rules.Add("<tracking_number>", {"<8_digits>"});
275   rules.Add("<tracking_number>", {"<9_digits>"});
276   rules.Add("<tracking_number>", {"<10_digits>"});
277   rules.AddValueMapping("<captured_tracking_number>", {"<tracking_number>"},
278                         /*value=*/0);
279   rules.Add("<parcel_tracking_trigger>", {"tracking", "number?", ":?"});
280 
281   const int classification_result_id = AddRuleClassificationResult(
282       "parcel_tracking", ModeFlag_ALL, 1.0, &grammar_model);
283   rules.Add(
284       "<parcel_tracking>",
285       {"<parcel_tracking_trigger>", "<captured_tracking_number>"},
286       /*callback=*/
287       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
288       /*callback_param=*/
289       classification_result_id);
290 
291   grammar_model.rule_classification_result[classification_result_id]
292       ->capturing_group.emplace_back(new CapturingGroupT);
293   grammar_model.rule_classification_result[classification_result_id]
294       ->capturing_group.back()
295       ->extend_selection = true;
296 
297   rules.Finalize().Serialize(/*include_debug_information=*/false,
298                              grammar_model.rules.get());
299   flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
300   GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
301 
302   ClassificationResult result;
303   EXPECT_TRUE(annotator.ClassifyText(
304       {Locale::FromBCP47("en")},
305       UTF8ToUnicodeText("Use tracking number 012345 for live parcel tracking.",
306                         /*do_copy=*/false),
307       CodepointSpan{20, 26}, &result));
308   EXPECT_THAT(result, IsClassificationResult("parcel_tracking"));
309 
310   EXPECT_FALSE(annotator.ClassifyText(
311       {Locale::FromBCP47("en")},
312       UTF8ToUnicodeText("Call phone 012345 for live parcel tracking.",
313                         /*do_copy=*/false),
314       CodepointSpan{11, 17}, &result));
315 }
316 
// SuggestSelection() should grow a selection inside a <flight> match to the
// full match.
TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) {
  // Set up the model skeleton the rules get serialized into.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Register the classification result, then attach it to the root rule.
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<flight>", {"<carrier>", "<flight_code>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/flight_result_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer model_buffer = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(model_buffer));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);

  // The one-codepoint selection at 14..15 lies inside "LX 38" (11..16).
  const CodepointSpan clicked_span{14, 15};
  AnnotatedSpan suggestion;
  EXPECT_TRUE(annotator.SuggestSelection(locales, text,
                                         /*selection=*/clicked_span,
                                         &suggestion));
  EXPECT_THAT(suggestion, IsAnnotatedSpan(11, 16, "flight"));
}
350 
// Entity data attached to a rule classification result should be copied
// verbatim into the classification of the annotated span.
TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) {
  // Create test rules.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  const int person_result =
      AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
  rules.Add(
      "<person>", {"barack", "obama"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/person_result);

  // Add test entity data.
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Former President Barack Obama");
  grammar_model.rule_classification_result[person_result]
      ->serialized_entity_data = entity_data->Serialize();

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));

  // The checks below are fatal (ASSERT_*) because the entity data inspection
  // dereferences the first span, its first classification and the decoded
  // string field. Continuing after a failure would be undefined behaviour
  // instead of a clean test failure.
  std::vector<AnnotatedSpan> result;
  ASSERT_TRUE(annotator.Annotate(
      {Locale::FromBCP47("en")},
      UTF8ToUnicodeText("I saw Barack Obama today", /*do_copy=*/false),
      &result));
  ASSERT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 18, "person")));
  ASSERT_FALSE(result.front().classification.empty());
  const auto& serialized_entity_data =
      result.front().classification.front().serialized_entity_data;
  ASSERT_FALSE(serialized_entity_data.empty());

  // Check entity data.
  // As we don't have generated code for the ad-hoc generated entity data
  // schema, we have to check manually using field offsets.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);
  // GetPointer() yields null when the field is absent from the table.
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Former President Barack Obama");
}
396 
TEST_F(GrammarAnnotatorTest,SetsEntityDataFromCapturingMatches)397 TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) {
398   // Create test rules.
399   GrammarModelT grammar_model;
400   SetTestTokenizerOptions(&grammar_model);
401   grammar_model.rules.reset(new grammar::RulesSetT);
402   grammar::LocaleShardMap locale_shard_map =
403       grammar::LocaleShardMap::CreateLocaleShardMap({""});
404   grammar::Rules rules(locale_shard_map);
405   const int person_result =
406       AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
407 
408   rules.Add("<person>", {"barack?", "obama"});
409   rules.Add("<person>", {"zapp?", "brannigan"});
410   rules.AddValueMapping("<captured_person>", {"<person>"},
411                         /*value=*/0);
412   rules.Add(
413       "<test>", {"<captured_person>"},
414       /*callback=*/
415       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
416       /*callback_param=*/person_result);
417 
418   // Set capturing group entity data information.
419   grammar_model.rule_classification_result[person_result]
420       ->capturing_group.emplace_back(new CapturingGroupT);
421   CapturingGroupT* group =
422       grammar_model.rule_classification_result[person_result]
423           ->capturing_group.back()
424           .get();
425   group->entity_field_path.reset(new FlatbufferFieldPathT);
426   group->entity_field_path->field.emplace_back(new FlatbufferFieldT);
427   group->entity_field_path->field.back()->field_name = "person";
428   group->normalization_options.reset(new NormalizationOptionsT);
429   group->normalization_options->codepointwise_normalization =
430       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
431 
432   rules.Finalize().Serialize(/*include_debug_information=*/false,
433                              grammar_model.rules.get());
434   flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
435   GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
436 
437   std::vector<AnnotatedSpan> result;
438   EXPECT_TRUE(annotator.Annotate(
439       {Locale::FromBCP47("en")},
440       UTF8ToUnicodeText("I saw Zapp Brannigan today", /*do_copy=*/false),
441       &result));
442   EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 20, "person")));
443 
444   // Check entity data.
445   // As we don't have generated code for the ad-hoc generated entity data
446   // schema, we have to check manually using field offsets.
447   const flatbuffers::Table* entity =
448       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
449           result.front().classification.front().serialized_entity_data.data()));
450   EXPECT_THAT(
451       entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
452       "ZAPP BRANNIGAN");
453 }
454 
TEST_F(GrammarAnnotatorTest,RespectsRuleModes)455 TEST_F(GrammarAnnotatorTest, RespectsRuleModes) {
456   // Create test rules.
457   GrammarModelT grammar_model;
458   SetTestTokenizerOptions(&grammar_model);
459   grammar_model.rules.reset(new grammar::RulesSetT);
460   grammar::LocaleShardMap locale_shard_map =
461       grammar::LocaleShardMap::CreateLocaleShardMap({""});
462   grammar::Rules rules(locale_shard_map);
463   rules.Add("<classification_carrier>", {"ei"});
464   rules.Add("<classification_carrier>", {"en"});
465   rules.Add("<selection_carrier>", {"ai"});
466   rules.Add("<selection_carrier>", {"bx"});
467   rules.Add("<annotation_carrier>", {"aa"});
468   rules.Add("<annotation_carrier>", {"lx"});
469   rules.Add("<flight_code>", {"<2_digits>"});
470   rules.Add("<flight_code>", {"<3_digits>"});
471   rules.Add("<flight_code>", {"<4_digits>"});
472   rules.Add(
473       "<flight>", {"<annotation_carrier>", "<flight_code>"},
474       /*callback=*/
475       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
476       /*callback_param=*/
477       AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
478   rules.Add(
479       "<flight>", {"<selection_carrier>", "<flight_code>"},
480       /*callback=*/
481       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
482       /*callback_param=*/
483       AddRuleClassificationResult("flight",
484                                   ModeFlag_CLASSIFICATION_AND_SELECTION, 1.0,
485                                   &grammar_model));
486   rules.Add(
487       "<flight>", {"<classification_carrier>", "<flight_code>"},
488       /*callback=*/
489       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
490       /*callback_param=*/
491       AddRuleClassificationResult("flight", ModeFlag_CLASSIFICATION, 1.0,
492                                   &grammar_model));
493   rules.Finalize().Serialize(/*include_debug_information=*/false,
494                              grammar_model.rules.get());
495   flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
496   GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
497 
498   const UnicodeText text = UTF8ToUnicodeText(
499       "My flight: LX 38 arriving at 4pm, I'll fly back on EI2014 but maybe "
500       "also on bx 222");
501   const std::vector<Locale> locales = {Locale::FromBCP47("en")};
502 
503   // Annotation, only high confidence pattern.
504   {
505     std::vector<AnnotatedSpan> result;
506     EXPECT_TRUE(annotator.Annotate(locales, text, &result));
507     EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight")));
508   }
509 
510   // Selection, annotation patterns + selection.
511   {
512     AnnotatedSpan selection;
513 
514     // Selects 'LX 38'.
515     EXPECT_TRUE(annotator.SuggestSelection(locales, text,
516                                            /*selection=*/CodepointSpan{14, 15},
517                                            &selection));
518     EXPECT_THAT(selection, IsAnnotatedSpan(11, 16, "flight"));
519 
520     // Selects 'bx 222'.
521     EXPECT_TRUE(annotator.SuggestSelection(locales, text,
522                                            /*selection=*/CodepointSpan{76, 77},
523                                            &selection));
524     EXPECT_THAT(selection, IsAnnotatedSpan(76, 82, "flight"));
525 
526     // Doesn't select 'EI2014'.
527     EXPECT_FALSE(annotator.SuggestSelection(locales, text,
528                                             /*selection=*/CodepointSpan{51, 51},
529                                             &selection));
530   }
531 
532   // Classification, all patterns.
533   {
534     ClassificationResult result;
535 
536     // Classifies 'LX 38'.
537     EXPECT_TRUE(
538         annotator.ClassifyText(locales, text, CodepointSpan{11, 16}, &result));
539     EXPECT_THAT(result, IsClassificationResult("flight"));
540 
541     // Classifies 'EI2014'.
542     EXPECT_TRUE(
543         annotator.ClassifyText(locales, text, CodepointSpan{51, 57}, &result));
544     EXPECT_THAT(result, IsClassificationResult("flight"));
545 
546     // Classifies 'bx 222'.
547     EXPECT_TRUE(
548         annotator.ClassifyText(locales, text, CodepointSpan{76, 82}, &result));
549     EXPECT_THAT(result, IsClassificationResult("flight"));
550   }
551 }
552 
553 }  // namespace
554 }  // namespace libtextclassifier3
555