1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "annotator/pod_ner/utils.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/tokenizer-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
28 
29 namespace libtextclassifier3 {
30 namespace {
31 
32 using ::testing::IsEmpty;
33 using ::testing::Not;
34 
35 using PodNerModel_::CollectionT;
36 using PodNerModel_::LabelT;
37 using PodNerModel_::Label_::BoiseType;
38 using PodNerModel_::Label_::BoiseType_BEGIN;
39 using PodNerModel_::Label_::BoiseType_END;
40 using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
41 using PodNerModel_::Label_::BoiseType_O;
42 using PodNerModel_::Label_::BoiseType_SINGLE;
43 using PodNerModel_::Label_::MentionType;
44 using PodNerModel_::Label_::MentionType_NAM;
45 using PodNerModel_::Label_::MentionType_NOM;
46 using PodNerModel_::Label_::MentionType_UNDEFINED;
47 
48 constexpr float kPriorityScore = 0.;
49 const std::vector<std::string>& kCollectionNames =
50     *new std::vector<std::string>{"undefined",    "location", "person", "art",
51                                   "organization", "entitiy",  "xxx"};
52 const auto& kStringToBoiseType = *new absl::flat_hash_map<
53     absl::string_view, libtextclassifier3::PodNerModel_::Label_::BoiseType>({
54     {"B", libtextclassifier3::PodNerModel_::Label_::BoiseType_BEGIN},
55     {"O", libtextclassifier3::PodNerModel_::Label_::BoiseType_O},
56     {"I", libtextclassifier3::PodNerModel_::Label_::BoiseType_INTERMEDIATE},
57     {"S", libtextclassifier3::PodNerModel_::Label_::BoiseType_SINGLE},
58     {"E", libtextclassifier3::PodNerModel_::Label_::BoiseType_END},
59 });
60 const auto& kStringToMentionType = *new absl::flat_hash_map<
61     absl::string_view, libtextclassifier3::PodNerModel_::Label_::MentionType>(
62     {{"NAM", libtextclassifier3::PodNerModel_::Label_::MentionType_NAM},
63      {"NOM", libtextclassifier3::PodNerModel_::Label_::MentionType_NOM}});
CreateLabel(BoiseType boise_type,MentionType mention_type,int collection_id)64 LabelT CreateLabel(BoiseType boise_type, MentionType mention_type,
65                    int collection_id) {
66   LabelT label;
67   label.boise_type = boise_type;
68   label.mention_type = mention_type;
69   label.collection_id = collection_id;
70   return label;
71 }
TagsToLabels(const std::vector<std::string> & tags)72 std::vector<PodNerModel_::LabelT> TagsToLabels(
73     const std::vector<std::string>& tags) {
74   std::vector<PodNerModel_::LabelT> labels;
75   for (const auto& tag : tags) {
76     if (tag == "O") {
77       labels.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
78     } else {
79       std::vector<absl::string_view> tag_parts = absl::StrSplit(tag, '-');
80       labels.emplace_back(CreateLabel(
81           kStringToBoiseType.at(tag_parts[0]),
82           kStringToMentionType.at(tag_parts[1]),
83           std::distance(
84               kCollectionNames.begin(),
85               std::find(kCollectionNames.begin(), kCollectionNames.end(),
86                         std::string(tag_parts[2].substr(
87                             tag_parts[2].rfind('/') + 1))))));
88     }
89   }
90   return labels;
91 }
92 
GetCollections()93 std::vector<CollectionT> GetCollections() {
94   std::vector<CollectionT> collections;
95   for (const std::string& collection_name : kCollectionNames) {
96     CollectionT collection;
97     collection.name = collection_name;
98     collection.single_token_priority_score = kPriorityScore;
99     collection.multi_token_priority_score = kPriorityScore;
100     collections.emplace_back(collection);
101   }
102   return collections;
103 }
104 
105 class ConvertTagsToAnnotatedSpansTest : public testing::TestWithParam<bool> {};
106 INSTANTIATE_TEST_SUITE_P(TagsAndLabelsTest, ConvertTagsToAnnotatedSpansTest,
107                          testing::Values(true, false));
108 
// Checks that a consistent B-I-E run over "New York City" yields exactly one
// "location" annotation with CodepointSpan(10, 23).
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesBIESequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal checks: the lines below index annotations[0] and classification[0],
  // which would read out of bounds if either were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
139 
// Checks that a single-token S tag on "father" yields exactly one "person"
// annotation with CodepointSpan(4, 10).
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesSSequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "His father was it.";
  std::vector<std::string> tags = {"O", "S-NAM-/saft/person", "O", "O"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal checks: the lines below index annotations[0] and classification[0],
  // which would read out of bounds if either were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(4, 10));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
}
165 
// Checks that several entities in one sentence (two person spans, an
// organization and a location) are all extracted. The expectations apply to
// both parameterizations; previously they were only checked in the `else`
// branch, so the string-tag overload was exercised but never verified.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultiple) {
  std::vector<AnnotatedSpan> annotations;
  std::string text =
      "Jaromir Jagr, Barak Obama and I met in Google New York City";
  std::vector<std::string> tags = {"B-NAM-/saft/person",
                                   "E-NAM-/saft/person",
                                   "B-NOM-/saft/person",
                                   "E-NOM-/saft/person",
                                   "O",
                                   "O",
                                   "O",
                                   "O",
                                   "S-NAM-/saft/organization",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  ASSERT_EQ(annotations.size(), 4);
  EXPECT_EQ(annotations[0].span, CodepointSpan(0, 13));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
  EXPECT_EQ(annotations[1].span, CodepointSpan(14, 25));
  ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[1].classification[0].collection, "person");
  EXPECT_EQ(annotations[2].span, CodepointSpan(39, 45));
  ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[2].classification[0].collection, "organization");
  EXPECT_EQ(annotations[3].span, CodepointSpan(46, 59));
  ASSERT_THAT(annotations[3].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[3].classification[0].collection, "location");
}
213 
// Checks that when the tag sequence covers only the tokens after the first
// two, the produced spans still carry the codepoint offsets of the full text.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultipleFirstTokenNotFirst) {
  std::vector<Token> original_tokens = TokenizeOnSpace(
      "Jaromir Jagr, Barak Obama and I met in Google New York City");
  // One tag per token, starting at the third token ("Barak").
  std::vector<std::string> tags = {
      "B-NOM-/saft/person",       "E-NOM-/saft/person",
      "O",                        "O",
      "O",                        "O",
      "S-NAM-/saft/organization", "B-NAM-/saft/location",
      "I-NAM-/saft/location",     "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(std::next(original_tokens.begin(), 2),
                          original_tokens.end()),
        TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(std::next(original_tokens.begin(), 2),
                          original_tokens.end()),
        tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  ASSERT_EQ(annotations.size(), 3);

  const AnnotatedSpan& person = annotations[0];
  EXPECT_EQ(person.span, CodepointSpan(14, 25));
  ASSERT_THAT(person.classification, Not(IsEmpty()));
  EXPECT_EQ(person.classification[0].collection, "person");

  const AnnotatedSpan& organization = annotations[1];
  EXPECT_EQ(organization.span, CodepointSpan(39, 45));
  ASSERT_THAT(organization.classification, Not(IsEmpty()));
  EXPECT_EQ(organization.classification[0].collection, "organization");

  const AnnotatedSpan& location = annotations[2];
  EXPECT_EQ(location.span, CodepointSpan(46, 59));
  ASSERT_THAT(location.classification, Not(IsEmpty()));
  EXPECT_EQ(location.classification[0].collection, "location");
}
257 
// Checks that the LabelT overload reports failure when a label refers to a
// collection that does not exist. TagsToLabels() turns the unknown name
// "invalid_collection" into an id past the end of GetCollections().
TEST(PodNerUtilsTest, ConvertTagsToAnnotatedSpansInvalidCollection) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "S-NAM-/saft/invalid_collection"};
  std::vector<AnnotatedSpan> annotations;

  const bool converted = ConvertTagsToAnnotatedSpans(
      VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
      GetCollections(),
      /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
      /*relaxed_inside_label_matching=*/false,
      /*relaxed_mention_type_matching=*/false, &annotations);
  ASSERT_FALSE(converted);
}
270 
// Checks that under strict matching, a span whose B tag names a different
// collection ("xxx") than its I/E tags ("location") produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentStart) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/xxx",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
298 
// Checks that under strict matching, a span whose B tag has mention type NOM
// while its I/E tags have NAM produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeStart) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NOM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
327 
// Checks that under strict inside-label matching, an I tag naming a different
// collection ("xxx") than the surrounding B/E tags produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentInside) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/location",
                                   "I-NAM-/saft/xxx",
                                   "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
356 
// Checks that under strict matching, an I tag with mention type NOM inside a
// NAM span produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeInside) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/location",
                                   "I-NOM-/saft/location",
                                   "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }
  EXPECT_THAT(annotations, IsEmpty());
}
384 
// Checks that with relaxed inside-label matching enabled, an I tag naming a
// different collection ("xxx") does not break the span: one "location"
// annotation with CodepointSpan(10, 23) is produced.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesInconsistentInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/xxx",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal checks: the lines below index annotations[0] and classification[0],
  // which would read out of bounds if either were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
415 
// Checks that under strict matching, an E tag naming a different collection
// ("xxx") than the B/I tags produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentEnd) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/xxx"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
444 
// Checks that under strict matching, an E tag with mention type NOM closing a
// NAM span produces no annotations.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeEnd) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NOM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
473 
// Checks that with relaxed mention-type matching enabled, a span mixing NOM
// and NAM tags for the same collection still produces one "location"
// annotation with CodepointSpan(10, 23).
TEST_P(
    ConvertTagsToAnnotatedSpansTest,
    ConvertTagsToAnnotatedSpansHandlesInconsistentLabelTypeWhenEntityMatches) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NOM-/saft/location",
                                   "I-NOM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/true, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/true, &annotations));
  }

  // Fatal checks: the lines below index annotations[0] and classification[0],
  // which would read out of bounds if either were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
505 
// Checks that a NAM span yields no annotations when the filter admits only
// NOM mentions.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresFilteredLabel) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{"NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
534 
// Checks that an empty label/mention filter yields no annotations, even for a
// well-formed span.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansWithEmptyLabelFilterIgnoresAll) {
  const std::string text = "We met in New York City";
  std::vector<Token> tokens = TokenizeOnSpace(text);
  std::vector<std::string> tags = {"O", "O", "O", "B-NOM-/saft/location",
                                   "I-NOM-/saft/location",
                                   "E-NOM-/saft/location"};
  std::vector<AnnotatedSpan> annotations;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens.begin(), tokens.end()), tags,
        /*label_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}
563 
TEST(PodNerUtilsTest,MergeLabelsIntoLeftSequence)564 TEST(PodNerUtilsTest, MergeLabelsIntoLeftSequence) {
565   std::vector<PodNerModel_::LabelT> original_labels_left;
566   original_labels_left.emplace_back(
567       CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
568   original_labels_left.emplace_back(
569       CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
570   original_labels_left.emplace_back(
571       CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
572   original_labels_left.emplace_back(
573       CreateLabel(BoiseType_SINGLE, MentionType_NAM, 1));
574   original_labels_left.emplace_back(
575       CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
576   original_labels_left.emplace_back(
577       CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
578   original_labels_left.emplace_back(
579       CreateLabel(BoiseType_SINGLE, MentionType_NAM, 2));
580 
581   std::vector<PodNerModel_::LabelT> labels_right;
582   labels_right.emplace_back(
583       CreateLabel(BoiseType_BEGIN, MentionType_UNDEFINED, 3));
584   labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
585   labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
586   labels_right.emplace_back(CreateLabel(BoiseType_BEGIN, MentionType_NAM, 4));
587   labels_right.emplace_back(
588       CreateLabel(BoiseType_INTERMEDIATE, MentionType_UNDEFINED, 4));
589   labels_right.emplace_back(
590       CreateLabel(BoiseType_END, MentionType_UNDEFINED, 4));
591   std::vector<PodNerModel_::LabelT> labels_left = original_labels_left;
592 
593   ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
594                                           /*index_first_right_tag_in_left=*/3,
595                                           &labels_left));
596   EXPECT_EQ(labels_left.size(), 9);
597   EXPECT_EQ(labels_left[0].collection_id, 0);
598   EXPECT_EQ(labels_left[1].collection_id, 0);
599   EXPECT_EQ(labels_left[2].collection_id, 0);
600   EXPECT_EQ(labels_left[3].collection_id, 1);
601   EXPECT_EQ(labels_left[4].collection_id, 0);
602   EXPECT_EQ(labels_left[5].collection_id, 0);
603   EXPECT_EQ(labels_left[6].collection_id, 4);
604   EXPECT_EQ(labels_left[7].collection_id, 4);
605   EXPECT_EQ(labels_left[8].collection_id, 4);
606 
607   labels_left = original_labels_left;
608   ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
609                                           /*index_first_right_tag_in_left=*/2,
610                                           &labels_left));
611   EXPECT_EQ(labels_left.size(), 8);
612   EXPECT_EQ(labels_left[0].collection_id, 0);
613   EXPECT_EQ(labels_left[1].collection_id, 0);
614   EXPECT_EQ(labels_left[2].collection_id, 0);
615   EXPECT_EQ(labels_left[3].collection_id, 1);
616   EXPECT_EQ(labels_left[4].collection_id, 0);
617   EXPECT_EQ(labels_left[5].collection_id, 4);
618   EXPECT_EQ(labels_left[6].collection_id, 4);
619   EXPECT_EQ(labels_left[7].collection_id, 4);
620 }
621 
// With a window limit (15) larger than the total number of wordpieces (12),
// the expected window is all wordpieces: WordpieceSpan(0, 12).
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanAllWordpices) {
  // Presumably the index of each token's first wordpiece (12 in total) --
  // confirm against FindWordpiecesWindowAroundSpan.
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> tokens{{"a", 0, 1},      {"b", 2, 3},
                            {"c", 4, 5},      {"d", 6, 7},
                            {"e", 8, 9},      {"f", 10, 11},
                            {"my", 12, 14},   {"name", 15, 19}};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {2, 3}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/15),
            WordpieceSpan(0, 12));
}
634 
// Windows around spans away from both ends of the text, for several window
// limits.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanInMiddle) {
  // Presumably the index of each token's first wordpiece (12 in total) --
  // confirm against FindWordpiecesWindowAroundSpan.
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> tokens{{"a", 0, 1},      {"b", 2, 3},
                            {"c", 4, 5},      {"d", 6, 7},
                            {"e", 8, 9},      {"f", 10, 11},
                            {"my", 12, 14},   {"name", 15, 19}};

  // Span of token "d".
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {6, 7}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/5),
            WordpieceSpan(3, 8));
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {6, 7}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/6),
            WordpieceSpan(3, 9));

  // Span of token "my".
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {12, 14}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/3),
            WordpieceSpan(9, 12));
}
659 
// A span near the start of the text; the expected window is
// WordpieceSpan(0, 7) for a limit of 7.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToStart) {
  // Presumably the index of each token's first wordpiece (12 in total) --
  // confirm against FindWordpiecesWindowAroundSpan.
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> tokens{{"a", 0, 1},      {"b", 2, 3},
                            {"c", 4, 5},      {"d", 6, 7},
                            {"e", 8, 9},      {"f", 10, 11},
                            {"my", 12, 14},   {"name", 15, 19}};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {2, 3}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/7),
            WordpieceSpan(0, 7));
}
672 
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToEnd) {
  std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> tokens = {{"a", 0, 1},    {"b", 2, 3},     {"c", 4, 5},
                               {"d", 6, 7},    {"e", 8, 9},     {"f", 10, 11},
                               {"my", 12, 14}, {"name", 15, 19}};

  // Token "name" is the last token, so a 7-wordpiece window around it is
  // expected to end at the final wordpiece (12).
  const WordpieceSpan window_near_end =
      internal::FindWordpiecesWindowAroundSpan(
          /*codepoints of "name"*/ {15, 19}, tokens, word_starts,
          /*num_wordpieces=*/12,
          /*max_num_wordpieces_in_window=*/7);
  EXPECT_EQ(window_near_end, WordpieceSpan(5, 12));
}
685 
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanBigSpan) {
  std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> tokens = {{"a", 0, 1},    {"b", 2, 3},     {"c", 4, 5},
                               {"d", 6, 7},    {"e", 8, 9},     {"f", 10, 11},
                               {"my", 12, 14}, {"name", 15, 19}};

  // Codepoints [0, 19) cover every token. Although that is more than the
  // 5-wordpiece budget, the full sequence [0, 12) is expected back.
  const WordpieceSpan covering_window =
      internal::FindWordpiecesWindowAroundSpan(
          /*all codepoints*/ {0, 19}, tokens, word_starts,
          /*num_wordpieces=*/12,
          /*max_num_wordpieces_in_window=*/5);
  EXPECT_EQ(covering_window, WordpieceSpan(0, 12));
}
698 
TEST(PodNerUtilsTest, FindFullTokensSpanInWindow) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Out-params start at a sentinel that no expectation below matches. If a
  // code path fails to write them, the test then fails deterministically
  // instead of reading an indeterminate value.
  int first_token_index = -1;
  int num_tokens = -1;

  // A window that already starts and ends on token boundaries is kept as-is.
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{0, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);

  // Reset so the second call cannot pass on values left over from the first.
  first_token_index = -1;
  num_tokens = -1;
  updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{2, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(2, 6));
  EXPECT_EQ(first_token_index, 1);
  EXPECT_EQ(num_tokens, 3);
}
718 
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowStartInMiddleOfToken) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Sentinel-initialized out-params: if they are left unwritten, the test
  // fails deterministically rather than reading indeterminate values.
  int first_token_index = -1;
  int num_tokens = -1;

  // Wordpiece 1 lies inside the token that starts at wordpiece 0. The
  // returned window is expected to begin at that token's start.
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{1, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);
}
730 
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowEndsInMiddleOfToken) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Sentinel-initialized out-params: if they are left unwritten, the test
  // fails deterministically rather than reading indeterminate values.
  int first_token_index = -1;
  int num_tokens = -1;

  // The requested end (wordpiece 9) falls inside the token starting at 7.
  // With a 6-wordpiece cap, the expected result is the four tokens in [0, 6).
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{1, 9},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);
}
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeOne) {
  // word_starts[1] == 2, and the next start is 3, so the token at index 1
  // is a single wordpiece long.
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  const int first_token =
      internal::FindFirstFullTokenIndex(word_starts,
                                        /*first_wordpiece_index=*/2);
  EXPECT_EQ(first_token, 1);
}
748 
TEST(PodNerUtilsTest, FindFirstFullTokenIndexFirst) {
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  // Wordpiece 0 precedes every entry in word_starts; token 0 is expected.
  const int first_token =
      internal::FindFirstFullTokenIndex(word_starts,
                                        /*first_wordpiece_index=*/0);
  EXPECT_EQ(first_token, 0);
}
755 
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeGreaterThanOne) {
  // word_starts[2] == 3 and word_starts[3] == 5, so token 2 is two wordpieces
  // long and wordpiece 4 is its second wordpiece.
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  const int first_token =
      internal::FindFirstFullTokenIndex(word_starts,
                                        /*first_wordpiece_index=*/4);
  EXPECT_EQ(first_token, 2);
}
762 
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeOne) {
  // Read as token start offsets, token 1 covers wordpieces [2, 3), so it
  // ends exactly at wordpiece_end.
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  const int last_token = internal::FindLastFullTokenIndex(
      word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/3);
  EXPECT_EQ(last_token, 1);
}
769 
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeGreaterThanOne) {
  // Read as token start offsets over 10 wordpieces, these describe tokens
  // [1,3), [3,4), [4,6), [6,8), [8,9) and [9,10).
  std::vector<int32_t> word_starts = {1, 3, 4, 6, 8, 9};

  // Looks up the last full token for a given end over the 10 wordpieces.
  auto last_full_token = [&word_starts](int wordpiece_end) -> int {
    return internal::FindLastFullTokenIndex(
        word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/wordpiece_end);
  };

  // Token 2 ([4, 6)) ends exactly at wordpiece 6.
  EXPECT_EQ(last_full_token(6), 2);
  // Token 3 ([6, 8)) runs past wordpiece 7, so token 2 is still the answer.
  EXPECT_EQ(last_full_token(7), 2);
  // Token 2 runs past wordpiece 5, leaving token 1 ([3, 4)).
  EXPECT_EQ(last_full_token(5), 1);
}
784 
TEST(PodNerUtilsTest, FindLastFullTokenIndexLast) {
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  // When wordpiece_end equals num_wordpieces, the final token (index 7) is
  // expected, whether it is one wordpiece long (12) or three (14).
  const int last_of_twelve = internal::FindLastFullTokenIndex(
      word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/12);
  EXPECT_EQ(last_of_twelve, 7);

  const int last_of_fourteen = internal::FindLastFullTokenIndex(
      word_starts, /*num_wordpieces=*/14, /*wordpiece_end=*/14);
  EXPECT_EQ(last_of_fourteen, 7);
}
795 
TEST(PodNerUtilsTest, FindLastFullTokenIndexBeforeLast) {
  // With 15 wordpieces, the final token would cover [11, 15), which runs past
  // wordpiece_end = 12. The token before it (index 6) is expected instead.
  std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  const int last_token = internal::FindLastFullTokenIndex(
      word_starts, /*num_wordpieces=*/15, /*wordpiece_end=*/12);
  EXPECT_EQ(last_token, 6);
}
802 
TEST(PodNerUtilsTest, ExpandWindowAndAlignSequenceSmallerThanMax) {
  // Only 8 wordpieces exist, fewer than the 10-wordpiece window, so the
  // expansion is expected to cover the whole sequence.
  const WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/8,
      /*wordpiece_span_to_expand=*/{2, 5});
  EXPECT_EQ(expanded, WordpieceSpan(0, 8));
}
809 
TEST(PodNerUtilsTest, ExpandWindowAndAlignWindowLengthGreaterThanMax) {
  // The input span (49 wordpieces) is already longer than the 10-wordpiece
  // window, so it is expected back unchanged.
  const WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/100,
      /*wordpiece_span_to_expand=*/{2, 51});
  EXPECT_EQ(expanded, WordpieceSpan(2, 51));
}
816 
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToStart) {
  // The span sits two wordpieces from the start, so the 10-wordpiece window
  // is expected to be pinned to [0, 10).
  const WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{2, 4});
  EXPECT_EQ(expanded, WordpieceSpan(0, 10));
}
823 
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToEnd) {
  // The span ends at the last wordpiece, so the 10-wordpiece window is
  // expected to be pinned to [10, 20).
  const WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{18, 20});
  EXPECT_EQ(expanded, WordpieceSpan(10, 20));
}
830 
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexInTheMiddle) {
  // A 2-wordpiece span needs 8 more wordpieces to fill a 10-wordpiece window.
  // They are expected to split evenly: [10 - 4, 12 + 4) = [6, 16).
  WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 12});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(6, 16));

  // A 3-wordpiece span needs 7 more. The odd extra wordpiece is expected on
  // the right: [10 - 3, 13 + 4) = [7, 17).
  maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 13});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(7, 17));
}
846 
TEST(PodNerUtilsTest,WindowGenerator)847 TEST(PodNerUtilsTest, WindowGenerator) {
848   std::vector<int32_t> wordpiece_indices = {10, 20, 30, 40, 50, 60, 70, 80};
849   std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
850                             {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}};
851   std::vector<int32_t> token_starts{0, 2, 3, 5, 6, 7};
852   WindowGenerator window_generator(wordpiece_indices, token_starts, tokens,
853                                    /*max_num_wordpieces=*/4,
854                                    /*sliding_window_overlap=*/1,
855                                    /*span_of_interest=*/{0, 12});
856   VectorSpan<int32_t> cur_wordpiece_indices;
857   VectorSpan<int32_t> cur_token_starts;
858   VectorSpan<Token> cur_tokens;
859   ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
860                                     &cur_tokens));
861   ASSERT_FALSE(window_generator.Done());
862   ASSERT_EQ(cur_wordpiece_indices.size(), 3);
863   for (int i = 0; i < 3; i++) {
864     ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i]);
865   }
866   ASSERT_EQ(cur_token_starts.size(), 2);
867   ASSERT_EQ(cur_tokens.size(), 2);
868   for (int i = 0; i < cur_tokens.size(); i++) {
869     ASSERT_EQ(cur_token_starts[i], token_starts[i]);
870     ASSERT_EQ(cur_tokens[i], tokens[i]);
871   }
872 
873   ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
874                                     &cur_tokens));
875   ASSERT_FALSE(window_generator.Done());
876   ASSERT_EQ(cur_wordpiece_indices.size(), 4);
877   for (int i = 0; i < cur_wordpiece_indices.size(); i++) {
878     ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 2]);
879   }
880   ASSERT_EQ(cur_token_starts.size(), 3);
881   ASSERT_EQ(cur_tokens.size(), 3);
882   for (int i = 0; i < cur_tokens.size(); i++) {
883     ASSERT_EQ(cur_token_starts[i], token_starts[i + 1]);
884     ASSERT_EQ(cur_tokens[i], tokens[i + 1]);
885   }
886 
887   ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
888                                     &cur_tokens));
889   ASSERT_TRUE(window_generator.Done());
890   ASSERT_EQ(cur_wordpiece_indices.size(), 3);
891   for (int i = 0; i < cur_wordpiece_indices.size(); i++) {
892     ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 5]);
893   }
894   ASSERT_EQ(cur_token_starts.size(), 3);
895   ASSERT_EQ(cur_tokens.size(), 3);
896   for (int i = 0; i < cur_tokens.size(); i++) {
897     ASSERT_EQ(cur_token_starts[i], token_starts[i + 3]);
898     ASSERT_EQ(cur_tokens[i], tokens[i + 3]);
899   }
900 
901   ASSERT_FALSE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
902                                      &cur_tokens));
903 }
904 }  // namespace
905 }  // namespace libtextclassifier3
906