1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/pod_ner/utils.h"
18
19 #include <iterator>
20
21 #include "annotator/model_generated.h"
22 #include "annotator/types.h"
23 #include "utils/tokenizer-utils.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/strings/str_split.h"
28
29 namespace libtextclassifier3 {
30 namespace {
31
32 using ::testing::IsEmpty;
33 using ::testing::Not;
34
35 using PodNerModel_::CollectionT;
36 using PodNerModel_::LabelT;
37 using PodNerModel_::Label_::BoiseType;
38 using PodNerModel_::Label_::BoiseType_BEGIN;
39 using PodNerModel_::Label_::BoiseType_END;
40 using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
41 using PodNerModel_::Label_::BoiseType_O;
42 using PodNerModel_::Label_::BoiseType_SINGLE;
43 using PodNerModel_::Label_::MentionType;
44 using PodNerModel_::Label_::MentionType_NAM;
45 using PodNerModel_::Label_::MentionType_NOM;
46 using PodNerModel_::Label_::MentionType_UNDEFINED;
47
48 constexpr float kPriorityScore = 0.;
49 const std::vector<std::string>& kCollectionNames =
50 *new std::vector<std::string>{"undefined", "location", "person", "art",
51 "organization", "entitiy", "xxx"};
52 const auto& kStringToBoiseType = *new absl::flat_hash_map<
53 absl::string_view, libtextclassifier3::PodNerModel_::Label_::BoiseType>({
54 {"B", libtextclassifier3::PodNerModel_::Label_::BoiseType_BEGIN},
55 {"O", libtextclassifier3::PodNerModel_::Label_::BoiseType_O},
56 {"I", libtextclassifier3::PodNerModel_::Label_::BoiseType_INTERMEDIATE},
57 {"S", libtextclassifier3::PodNerModel_::Label_::BoiseType_SINGLE},
58 {"E", libtextclassifier3::PodNerModel_::Label_::BoiseType_END},
59 });
60 const auto& kStringToMentionType = *new absl::flat_hash_map<
61 absl::string_view, libtextclassifier3::PodNerModel_::Label_::MentionType>(
62 {{"NAM", libtextclassifier3::PodNerModel_::Label_::MentionType_NAM},
63 {"NOM", libtextclassifier3::PodNerModel_::Label_::MentionType_NOM}});
CreateLabel(BoiseType boise_type,MentionType mention_type,int collection_id)64 LabelT CreateLabel(BoiseType boise_type, MentionType mention_type,
65 int collection_id) {
66 LabelT label;
67 label.boise_type = boise_type;
68 label.mention_type = mention_type;
69 label.collection_id = collection_id;
70 return label;
71 }
TagsToLabels(const std::vector<std::string> & tags)72 std::vector<PodNerModel_::LabelT> TagsToLabels(
73 const std::vector<std::string>& tags) {
74 std::vector<PodNerModel_::LabelT> labels;
75 for (const auto& tag : tags) {
76 if (tag == "O") {
77 labels.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
78 } else {
79 std::vector<absl::string_view> tag_parts = absl::StrSplit(tag, '-');
80 labels.emplace_back(CreateLabel(
81 kStringToBoiseType.at(tag_parts[0]),
82 kStringToMentionType.at(tag_parts[1]),
83 std::distance(
84 kCollectionNames.begin(),
85 std::find(kCollectionNames.begin(), kCollectionNames.end(),
86 std::string(tag_parts[2].substr(
87 tag_parts[2].rfind('/') + 1))))));
88 }
89 }
90 return labels;
91 }
92
GetCollections()93 std::vector<CollectionT> GetCollections() {
94 std::vector<CollectionT> collections;
95 for (const std::string& collection_name : kCollectionNames) {
96 CollectionT collection;
97 collection.name = collection_name;
98 collection.single_token_priority_score = kPriorityScore;
99 collection.multi_token_priority_score = kPriorityScore;
100 collections.emplace_back(collection);
101 }
102 return collections;
103 }
104
105 class ConvertTagsToAnnotatedSpansTest : public testing::TestWithParam<bool> {};
106 INSTANTIATE_TEST_SUITE_P(TagsAndLabelsTest, ConvertTagsToAnnotatedSpansTest,
107 testing::Values(true, false));
108
// A consistent B/I/E tag sequence produces a single annotation that covers all
// three tokens and carries the tagged collection.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesBIESequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal assertions: the element accesses below would read out of bounds if
  // either container were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
139
// A single S tag produces a one-token annotation with the tagged collection.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesSSequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "His father was it.";
  std::vector<std::string> tags = {"O", "S-NAM-/saft/person", "O", "O"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal assertions: the element accesses below would read out of bounds if
  // either container were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(4, 10));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
}
165
// Several entities in one sentence (two persons, an organization and a
// location) each produce their own annotation, in token order.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultiple) {
  std::vector<AnnotatedSpan> annotations;
  std::string text =
      "Jaromir Jagr, Barak Obama and I met in Google New York City";
  std::vector<std::string> tags = {"B-NAM-/saft/person",
                                   "E-NAM-/saft/person",
                                   "B-NOM-/saft/person",
                                   "E-NOM-/saft/person",
                                   "O",
                                   "O",
                                   "O",
                                   "O",
                                   "S-NAM-/saft/organization",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // The expectations apply to both overloads, so they live outside the
  // if/else; previously they ran only for the label-based overload.
  ASSERT_EQ(annotations.size(), 4);
  EXPECT_EQ(annotations[0].span, CodepointSpan(0, 13));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
  EXPECT_EQ(annotations[1].span, CodepointSpan(14, 25));
  ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[1].classification[0].collection, "person");
  EXPECT_EQ(annotations[2].span, CodepointSpan(39, 45));
  ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[2].classification[0].collection, "organization");
  EXPECT_EQ(annotations[3].span, CodepointSpan(46, 59));
  ASSERT_THAT(annotations[3].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[3].classification[0].collection, "location");
}
213
// The token span starts at the third token of the tokenized text; the
// resulting annotation spans must still be codepoint offsets into the full
// text.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultipleFirstTokenNotFirst) {
  std::vector<Token> all_tokens = TokenizeOnSpace(
      "Jaromir Jagr, Barak Obama and I met in Google New York City");
  // One tag per token, starting from the third token ("Barak").
  const std::vector<std::string> tags = {
      "B-NOM-/saft/person",       "E-NOM-/saft/person",
      "O",                        "O",
      "O",                        "O",
      "S-NAM-/saft/organization", "B-NAM-/saft/location",
      "I-NAM-/saft/location",     "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(all_tokens.begin() + 2, all_tokens.end()), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(all_tokens.begin() + 2, all_tokens.end()),
        TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  ASSERT_EQ(result.size(), 3);

  // "Barak Obama"
  EXPECT_EQ(result[0].span, CodepointSpan(14, 25));
  ASSERT_THAT(result[0].classification, Not(IsEmpty()));
  EXPECT_EQ(result[0].classification[0].collection, "person");

  // "Google"
  EXPECT_EQ(result[1].span, CodepointSpan(39, 45));
  ASSERT_THAT(result[1].classification, Not(IsEmpty()));
  EXPECT_EQ(result[1].classification[0].collection, "organization");

  // "New York City"
  EXPECT_EQ(result[2].span, CodepointSpan(46, 59));
  ASSERT_THAT(result[2].classification, Not(IsEmpty()));
  EXPECT_EQ(result[2].classification[0].collection, "location");
}
257
// A label whose collection name is not in kCollectionNames makes the
// label-based overload return false.
TEST(PodNerUtilsTest, ConvertTagsToAnnotatedSpansInvalidCollection) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {"O", "O",
                                         "S-NAM-/saft/invalid_collection"};

  std::vector<AnnotatedSpan> result;
  ASSERT_FALSE(ConvertTagsToAnnotatedSpans(
      VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
      /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
      /*relaxed_inside_label_matching=*/false,
      /*relaxed_mention_type_matching=*/false, &result));
}
270
// With both relaxed flags off, a BEGIN tag whose collection ("xxx") differs
// from the following I/E tags yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentStart) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/xxx", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
298
// With both relaxed flags off, a BEGIN tag whose mention type (NOM) differs
// from the following I/E tags (NAM) yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeStart) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NOM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
327
// With relaxed inside-label matching off, an INTERMEDIATE tag whose collection
// ("xxx") differs from the B/E tags yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentInside) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/xxx", "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
356
// With both relaxed flags off, an INTERMEDIATE tag whose mention type (NOM)
// differs from the B/E tags (NAM) yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeInside) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NOM-/saft/location", "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
384
// With relaxed inside-label matching enabled, an INTERMEDIATE tag whose
// collection ("xxx") differs from the B/E tags still yields one annotation
// carrying the B/E collection.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesInconsistentInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/xxx",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal assertions: the element accesses below would read out of bounds if
  // either container were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
415
// With both relaxed flags off, an END tag whose collection ("xxx") differs
// from the preceding B/I tags yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentEnd) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/xxx"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
444
// With both relaxed flags off, an END tag whose mention type (NOM) differs
// from the preceding B/I tags (NAM) yields no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeEnd) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NOM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
473
// With relaxed label-category / mention-type matching enabled, an END tag
// whose mention type (NAM) differs from the B/I tags (NOM) still yields one
// annotation, because all tags share the same collection.
TEST_P(
    ConvertTagsToAnnotatedSpansTest,
    ConvertTagsToAnnotatedSpansHandlesInconsistentLabelTypeWhenEntityMatches) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NOM-/saft/location",
                                   "I-NOM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/true, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/true, &annotations));
  }

  // Fatal assertions: the element accesses below would read out of bounds if
  // either container were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
505
// A consistent NAM span yields no annotation when the filter only admits NOM.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresFilteredLabel) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{"NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
534
// An empty label/mention filter admits nothing: a consistent NOM span yields
// no annotation.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansWithEmptyLabelFilterIgnoresAll) {
  const std::string sentence = "We met in New York City";
  const std::vector<Token> tokens = TokenizeOnSpace(sentence);
  const std::vector<std::string> tags = {
      "O", "O", "O",
      "B-NOM-/saft/location", "I-NOM-/saft/location", "E-NOM-/saft/location"};
  const bool use_string_tags = GetParam();

  std::vector<AnnotatedSpan> result;
  if (use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tags,
        /*label_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &result));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &result));
  }

  EXPECT_THAT(result, IsEmpty());
}
563
// Merges a 6-label right sequence into a 7-label left sequence at two
// different offsets and checks the collection ids of the merged result.
TEST(PodNerUtilsTest, MergeLabelsIntoLeftSequence) {
  const std::vector<PodNerModel_::LabelT> original_labels_left = {
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_SINGLE, MentionType_NAM, 1),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_SINGLE, MentionType_NAM, 2)};

  const std::vector<PodNerModel_::LabelT> labels_right = {
      CreateLabel(BoiseType_BEGIN, MentionType_UNDEFINED, 3),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0),
      CreateLabel(BoiseType_BEGIN, MentionType_NAM, 4),
      CreateLabel(BoiseType_INTERMEDIATE, MentionType_UNDEFINED, 4),
      CreateLabel(BoiseType_END, MentionType_UNDEFINED, 4)};

  // Projects labels onto their collection ids so the whole sequence (length
  // and contents) is compared in one expectation, with no element access that
  // could run out of bounds if the merged length is wrong.
  const auto collection_ids =
      [](const std::vector<PodNerModel_::LabelT>& labels) {
        std::vector<int> ids;
        for (const PodNerModel_::LabelT& label : labels) {
          ids.push_back(label.collection_id);
        }
        return ids;
      };

  std::vector<PodNerModel_::LabelT> labels_left = original_labels_left;
  ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
                                          /*index_first_right_tag_in_left=*/3,
                                          &labels_left));
  EXPECT_EQ(collection_ids(labels_left),
            (std::vector<int>{0, 0, 0, 1, 0, 0, 4, 4, 4}));

  labels_left = original_labels_left;
  ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
                                          /*index_first_right_tag_in_left=*/2,
                                          &labels_left));
  EXPECT_EQ(collection_ids(labels_left),
            (std::vector<int>{0, 0, 0, 1, 0, 4, 4, 4}));
}
621
// When the maximum window (15) exceeds the number of wordpieces (12), the
// window covers all wordpieces.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanAllWordpices) {
  const std::vector<Token> tokens = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};
  const std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {2, 3}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/15),
            WordpieceSpan(0, 12));
}
634
// Windows of limited size around spans in the middle and near the end of the
// text, for several maximum window sizes.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanInMiddle) {
  const std::vector<Token> tokens = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};
  const std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};

  // Span of token "d", window of at most 5 wordpieces.
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {6, 7}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/5),
            WordpieceSpan(3, 8));

  // Same span, window of at most 6 wordpieces.
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {6, 7}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/6),
            WordpieceSpan(3, 9));

  // Span of token "my", window of at most 3 wordpieces.
  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {12, 14}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/3),
            WordpieceSpan(9, 12));
}
659
// A span near the beginning of the text gets a window starting at wordpiece 0.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToStart) {
  const std::vector<Token> tokens = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};
  const std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {2, 3}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/7),
            WordpieceSpan(0, 7));
}
672
// A span at the end of the text gets a window ending at the last wordpiece.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToEnd) {
  const std::vector<Token> tokens = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};
  const std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {15, 19}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/7),
            WordpieceSpan(5, 12));
}
685
// A span covering the whole text returns all 12 wordpieces even though the
// maximum window size is only 5.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanBigSpan) {
  const std::vector<Token> tokens = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};
  const std::vector<int32_t> word_starts = {0, 2, 3, 5, 6, 7, 10, 11};

  EXPECT_EQ(internal::FindWordpiecesWindowAroundSpan(
                {0, 19}, tokens, word_starts,
                /*num_wordpieces=*/12,
                /*max_num_wordpieces_in_window=*/5),
            WordpieceSpan(0, 12));
}
698
// Wordpiece spans that already start and end on token boundaries are returned
// unchanged, together with the index of the first token and the token count.
TEST(PodNerUtilsTest, FindFullTokensSpanInWindow) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Out-params start at a sentinel so that a callee that fails to write them
  // produces a deterministic test failure rather than an indeterminate read.
  int first_token_index = -1;
  int num_tokens = -1;
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{0, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);

  first_token_index = -1;
  num_tokens = -1;
  updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{2, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(2, 6));
  EXPECT_EQ(first_token_index, 1);
  EXPECT_EQ(num_tokens, 3);
}
718
// A wordpiece span starting inside the first token (wordpiece 1) is reported
// as starting at that token's first wordpiece (0).
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowStartInMiddleOfToken) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Out-params start at a sentinel so that a callee that fails to write them
  // produces a deterministic test failure rather than an indeterminate read.
  int first_token_index = -1;
  int num_tokens = -1;
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{1, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);
}
730
// A wordpiece span {1, 9} with a maximum of 6 wordpieces is reported as
// {0, 6}, covering 4 full tokens starting at token 0.
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowEndsInMiddleOfToken) {
  std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
  // Out-params start at a sentinel so that a callee that fails to write them
  // produces a deterministic test failure rather than an indeterminate read.
  int first_token_index = -1;
  int num_tokens = -1;
  WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
      word_starts, /*wordpiece_span=*/{1, 9},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
      &num_tokens);
  EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token_index, 0);
  EXPECT_EQ(num_tokens, 4);
}
// Wordpiece 2 is exactly the start of token 1.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeOne) {
  const std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  EXPECT_EQ(internal::FindFirstFullTokenIndex(word_starts,
                                              /*first_wordpiece_index=*/2),
            1);
}
748
// Starting from wordpiece 0, the first full token is token 0.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexFirst) {
  const std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  EXPECT_EQ(internal::FindFirstFullTokenIndex(word_starts,
                                              /*first_wordpiece_index=*/0),
            0);
}
755
// Wordpiece 4 lies inside the token that starts at wordpiece 3 (token 2); the
// expected result is that token's index.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeGreaterThanOne) {
  const std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  EXPECT_EQ(internal::FindFirstFullTokenIndex(word_starts,
                                              /*first_wordpiece_index=*/4),
            2);
}
762
// With the window ending at wordpiece 3, the last full token is token 1.
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeOne) {
  const std::vector<int32_t> word_starts = {1, 2, 3, 5, 6, 7, 10, 11};
  EXPECT_EQ(internal::FindLastFullTokenIndex(word_starts,
                                             /*num_wordpieces=*/12,
                                             /*wordpiece_end=*/3),
            1);
}
769
// Checks the last full token for window ends on and around the boundary of a
// multi-wordpiece token (token 2 starts at wordpiece 4, token 3 at 6).
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeGreaterThanOne) {
  const std::vector<int32_t> word_starts = {1, 3, 4, 6, 8, 9};

  EXPECT_EQ(internal::FindLastFullTokenIndex(word_starts,
                                             /*num_wordpieces=*/10,
                                             /*wordpiece_end=*/6),
            2);
  EXPECT_EQ(internal::FindLastFullTokenIndex(word_starts,
                                             /*num_wordpieces=*/10,
                                             /*wordpiece_end=*/7),
            2);
  EXPECT_EQ(internal::FindLastFullTokenIndex(word_starts,
                                             /*num_wordpieces=*/10,
                                             /*wordpiece_end=*/5),
            1);
}
784
// When wordpiece_end equals num_wordpieces (12 and 14 are both tried),
// FindLastFullTokenIndex is expected to yield index 7, the last entry of the
// word starts below.
TEST(PodNerUtilsTest, FindLastFullTokenIndexLast) {
  std::vector<int32_t> token_start_wordpieces = {1, 2, 3, 5, 6, 7, 10, 11};

  const int last_token_with_12_wordpieces = internal::FindLastFullTokenIndex(
      token_start_wordpieces, /*num_wordpieces=*/12, /*wordpiece_end=*/12);
  EXPECT_EQ(last_token_with_12_wordpieces, 7);

  const int last_token_with_14_wordpieces = internal::FindLastFullTokenIndex(
      token_start_wordpieces, /*num_wordpieces=*/14, /*wordpiece_end=*/14);
  EXPECT_EQ(last_token_with_14_wordpieces, 7);
}
795
// With 15 wordpieces in total and wordpiece_end=12, FindLastFullTokenIndex is
// expected to yield index 6, i.e. the entry before the last word start.
TEST(PodNerUtilsTest, FindLastFullTokenIndexBeforeLast) {
  std::vector<int32_t> token_start_wordpieces = {1, 2, 3, 5, 6, 7, 10, 11};
  EXPECT_EQ(internal::FindLastFullTokenIndex(token_start_wordpieces,
                                             /*num_wordpieces=*/15,
                                             /*wordpiece_end=*/12),
            6);
}
802
// With only 8 wordpieces available and a window limit of 10, expanding the
// span {2, 5} is expected to produce the span {0, 8}.
TEST(PodNerUtilsTest, ExpandWindowAndAlignSequenceSmallerThanMax) {
  const WordpieceSpan expanded_span = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/8,
      /*wordpiece_span_to_expand=*/{2, 5});
  const WordpieceSpan expected_span(0, 8);
  EXPECT_EQ(expanded_span, expected_span);
}
809
// The span {2, 51} is passed with a window limit of 10 (100 wordpieces in
// total); it is expected to come back unchanged.
TEST(PodNerUtilsTest, ExpandWindowAndAlignWindowLengthGreaterThanMax) {
  const WordpieceSpan expanded_span = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/100,
      /*wordpiece_span_to_expand=*/{2, 51});
  const WordpieceSpan expected_span(2, 51);
  EXPECT_EQ(expanded_span, expected_span);
}
816
// Expanding the span {2, 4} to a 10-wordpiece window inside a 20-wordpiece
// sequence is expected to produce {0, 10}.
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToStart) {
  const WordpieceSpan expanded_span = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{2, 4});
  const WordpieceSpan expected_span(0, 10);
  EXPECT_EQ(expanded_span, expected_span);
}
823
// Expanding the span {18, 20} to a 10-wordpiece window inside a 20-wordpiece
// sequence is expected to produce {10, 20}.
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToEnd) {
  const WordpieceSpan expanded_span = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{18, 20});
  const WordpieceSpan expected_span(10, 20);
  EXPECT_EQ(expanded_span, expected_span);
}
830
// Expands two spans that start at wordpiece 10 of a 20-wordpiece sequence to a
// 10-wordpiece window: {10, 12} is expected to become {6, 16} and {10, 13} is
// expected to become {7, 17}.
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexInTheMiddle) {
  // The spans are passed as literals; the previously declared
  // window_first/last_wordpiece_index locals were written but never read, so
  // they have been dropped.
  WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 12});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(6, 16));

  maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 13});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(7, 17));
}
846
// Drives a WindowGenerator over 8 wordpieces and 6 tokens, with at most 4
// wordpieces per window and a sliding-window overlap of 1, and verifies the
// three windows it hands out before Next() reports that nothing is left.
TEST(PodNerUtilsTest, WindowGenerator) {
  std::vector<int32_t> all_wordpieces = {10, 20, 30, 40, 50, 60, 70, 80};
  std::vector<Token> all_tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
                                {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}};
  std::vector<int32_t> all_token_starts{0, 2, 3, 5, 6, 7};
  WindowGenerator generator(all_wordpieces, all_token_starts, all_tokens,
                            /*max_num_wordpieces=*/4,
                            /*sliding_window_overlap=*/1,
                            /*span_of_interest=*/{0, 12});
  VectorSpan<int32_t> window_wordpieces;
  VectorSpan<int32_t> window_token_starts;
  VectorSpan<Token> window_tokens;

  // First window: wordpieces [0, 3) and tokens [0, 2) of the inputs.
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_FALSE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 3);
  for (int wp = 0; wp < 3; ++wp) {
    ASSERT_EQ(window_wordpieces[wp], all_wordpieces[wp]);
  }
  ASSERT_EQ(window_token_starts.size(), 2);
  ASSERT_EQ(window_tokens.size(), 2);
  for (int tok = 0; tok < window_tokens.size(); ++tok) {
    ASSERT_EQ(window_token_starts[tok], all_token_starts[tok]);
    ASSERT_EQ(window_tokens[tok], all_tokens[tok]);
  }

  // Second window: wordpieces [2, 6) and tokens [1, 4) of the inputs.
  const int second_wordpiece_offset = 2;
  const int second_token_offset = 1;
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_FALSE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 4);
  for (int wp = 0; wp < window_wordpieces.size(); ++wp) {
    ASSERT_EQ(window_wordpieces[wp],
              all_wordpieces[wp + second_wordpiece_offset]);
  }
  ASSERT_EQ(window_token_starts.size(), 3);
  ASSERT_EQ(window_tokens.size(), 3);
  for (int tok = 0; tok < window_tokens.size(); ++tok) {
    ASSERT_EQ(window_token_starts[tok],
              all_token_starts[tok + second_token_offset]);
    ASSERT_EQ(window_tokens[tok], all_tokens[tok + second_token_offset]);
  }

  // Third window: wordpieces [5, 8) and tokens [3, 6) of the inputs. The
  // generator must report Done() once this window has been handed out.
  const int third_wordpiece_offset = 5;
  const int third_token_offset = 3;
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_TRUE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 3);
  for (int wp = 0; wp < window_wordpieces.size(); ++wp) {
    ASSERT_EQ(window_wordpieces[wp],
              all_wordpieces[wp + third_wordpiece_offset]);
  }
  ASSERT_EQ(window_token_starts.size(), 3);
  ASSERT_EQ(window_tokens.size(), 3);
  for (int tok = 0; tok < window_tokens.size(); ++tok) {
    ASSERT_EQ(window_token_starts[tok],
              all_token_starts[tok + third_token_offset]);
    ASSERT_EQ(window_tokens[tok], all_tokens[tok + third_token_offset]);
  }

  // A further call must signal exhaustion.
  ASSERT_FALSE(generator.Next(&window_wordpieces, &window_token_starts,
                              &window_tokens));
}
904 } // namespace
905 } // namespace libtextclassifier3
906