/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): the template argument lists and the system #include target in
// this file looked stripped (e.g. a bare `std::vector annotations;`). They are
// reconstructed here from usage; `AnnotatedSpan` and `Token` are presumed to
// come from "annotator/types.h" -- TODO confirm against the project headers.

#include "annotator/pod_ner/utils.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/tokenizer-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"

namespace libtextclassifier3 {
namespace {

using ::testing::IsEmpty;
using ::testing::Not;

using PodNerModel_::CollectionT;
using PodNerModel_::LabelT;
using PodNerModel_::Label_::BoiseType;
using PodNerModel_::Label_::BoiseType_BEGIN;
using PodNerModel_::Label_::BoiseType_END;
using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
using PodNerModel_::Label_::BoiseType_O;
using PodNerModel_::Label_::BoiseType_SINGLE;
using PodNerModel_::Label_::MentionType;
using PodNerModel_::Label_::MentionType_NAM;
using PodNerModel_::Label_::MentionType_NOM;
using PodNerModel_::Label_::MentionType_UNDEFINED;

// Priority score used for every collection and for the string-tag overload.
constexpr float kPriorityScore = 0.;

// Collection names; the position of a name in this list is its collection id.
// The globals below are heap-allocated and never freed, so no destructor runs
// for them at program exit.
const std::vector<std::string>& kCollectionNames =
    *new std::vector<std::string>{"undefined",    "location", "person", "art",
                                  "organization", "entitiy",  "xxx"};

// Maps the leading letter of a tag to its BOISE type.
const auto& kStringToBoiseType =
    *new absl::flat_hash_map<absl::string_view, BoiseType>({
        {"B", BoiseType_BEGIN},
        {"O", BoiseType_O},
        {"I", BoiseType_INTERMEDIATE},
        {"S", BoiseType_SINGLE},
        {"E", BoiseType_END},
    });

// Maps the middle part of a tag to its mention type.
const auto& kStringToMentionType =
    *new absl::flat_hash_map<absl::string_view, MentionType>(
        {{"NAM", MentionType_NAM}, {"NOM", MentionType_NOM}});

// Builds a LabelT from its three components.
LabelT CreateLabel(BoiseType boise_type, MentionType mention_type,
                   int collection_id) {
  LabelT label;
  label.boise_type = boise_type;
  label.mention_type = mention_type;
  label.collection_id = collection_id;
  return label;
}

// Converts string tags into LabelT objects.
//
// A tag is either "O" or has three '-'-separated parts: the BOISE letter, the
// mention type, and a collection path whose last '/'-separated component is
// the collection name (e.g. "B-NAM-/saft/location"). The collection id is the
// index of that name in kCollectionNames; a name that is not in the list maps
// to kCollectionNames.size(), i.e. an id with no matching collection.
//
// Malformed tags (fewer than three parts, unknown letter or mention type) are
// not handled; callers in this file only pass well-formed tags.
std::vector<LabelT> TagsToLabels(const std::vector<std::string>& tags) {
  std::vector<LabelT> labels;
  for (const auto& tag : tags) {
    if (tag == "O") {
      labels.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
      continue;
    }

    std::vector<absl::string_view> tag_parts = absl::StrSplit(tag, '-');
    // Keep only what follows the last '/'. When there is no '/', rfind()
    // returns npos and npos + 1 wraps to 0, so the whole part is kept.
    const std::string collection_name(
        tag_parts[2].substr(tag_parts[2].rfind('/') + 1));
    const auto name_it = std::find(kCollectionNames.begin(),
                                   kCollectionNames.end(), collection_name);
    const int collection_id =
        static_cast<int>(std::distance(kCollectionNames.begin(), name_it));
    labels.emplace_back(CreateLabel(kStringToBoiseType.at(tag_parts[0]),
                                    kStringToMentionType.at(tag_parts[1]),
                                    collection_id));
  }
  return labels;
}

// Returns one CollectionT per entry of kCollectionNames, in the same order,
// with both priority scores set to kPriorityScore.
std::vector<CollectionT> GetCollections() {
  std::vector<CollectionT> collections;
  collections.reserve(kCollectionNames.size());
  for (const std::string& collection_name : kCollectionNames) {
    CollectionT collection;
    collection.name = collection_name;
    collection.single_token_priority_score = kPriorityScore;
    collection.multi_token_priority_score = kPriorityScore;
    collections.emplace_back(collection);
  }
  return collections;
}

// Parameterized over the ConvertTagsToAnnotatedSpans overload under test:
// true  -> the overload taking string tags and a string label filter,
// false -> the overload taking LabelT labels, collections and a mention filter.
class ConvertTagsToAnnotatedSpansTest : public testing::TestWithParam<bool> {};

INSTANTIATE_TEST_SUITE_P(TagsAndLabelsTest, ConvertTagsToAnnotatedSpansTest,
                         testing::Values(true, false));

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesBIESequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // Fatal size checks: the element accesses below would otherwise read out of
  // bounds when the conversion yields nothing.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesSSequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "His father was it.";
  std::vector<std::string> tags = {"O", "S-NAM-/saft/person", "O", "O"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(4, 10));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultiple) {
  std::vector<AnnotatedSpan> annotations;
  std::string text =
      "Jaromir Jagr, Barak Obama and I met in Google New York City";
  std::vector<std::string> tags = {"B-NAM-/saft/person",
                                   "E-NAM-/saft/person",
                                   "B-NOM-/saft/person",
                                   "E-NOM-/saft/person",
                                   "O",
                                   "O",
                                   "O",
                                   "O",
                                   "S-NAM-/saft/organization",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // The expectations apply to both overloads, so they live after the if/else
  // rather than inside one branch.
  ASSERT_EQ(annotations.size(), 4);
  EXPECT_EQ(annotations[0].span, CodepointSpan(0, 13));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
  EXPECT_EQ(annotations[1].span, CodepointSpan(14, 25));
  ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[1].classification[0].collection, "person");
  EXPECT_EQ(annotations[2].span, CodepointSpan(39, 45));
  ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[2].classification[0].collection, "organization");
  EXPECT_EQ(annotations[3].span, CodepointSpan(46, 59));
  ASSERT_THAT(annotations[3].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[3].classification[0].collection, "location");
}

// Same sentence as above, but the span handed to the converter starts at the
// third token, so the reported codepoint spans must still be relative to the
// original text.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultipleFirstTokenNotFirst) {
  std::vector<AnnotatedSpan> annotations;
  std::vector<Token> original_tokens = TokenizeOnSpace(
      "Jaromir Jagr, Barak Obama and I met in Google New York City");
  std::vector<std::string> tags = {"B-NOM-/saft/person",
                                   "E-NOM-/saft/person",
                                   "O",
                                   "O",
                                   "O",
                                   "O",
                                   "S-NAM-/saft/organization",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(original_tokens.begin() + 2, original_tokens.end()),
        tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(original_tokens.begin() + 2, original_tokens.end()),
        TagsToLabels(tags), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  ASSERT_EQ(annotations.size(), 3);
  EXPECT_EQ(annotations[0].span, CodepointSpan(14, 25));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
  EXPECT_EQ(annotations[1].span, CodepointSpan(39, 45));
  ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[1].classification[0].collection, "organization");
  EXPECT_EQ(annotations[2].span, CodepointSpan(46, 59));
  ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[2].classification[0].collection, "location");
}

// Only the LabelT overload is exercised: "invalid_collection" is not in
// kCollectionNames, so TagsToLabels gives it an id past the last collection.
// NOTE(review): the text has six tokens but only three tags are supplied, so
// the call may already fail on the length mismatch -- confirm that the invalid
// collection id is what this test actually exercises.
TEST(PodNerUtilsTest, ConvertTagsToAnnotatedSpansInvalidCollection) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O", "O", "S-NAM-/saft/invalid_collection"};

  ASSERT_FALSE(ConvertTagsToAnnotatedSpans(
      VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
      GetCollections(),
      /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
      /*relaxed_inside_label_matching=*/false,
      /*relaxed_mention_type_matching=*/false, &annotations));
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentStart) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/xxx", "I-NAM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeStart) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NOM-/saft/location", "I-NAM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/xxx",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NOM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

// Same tags as IgnoresInconsistentInside, but with relaxed inside-label
// matching enabled the span is expected to be produced.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesInconsistentInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/xxx",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentEnd) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/location",
      "E-NAM-/saft/xxx"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeEnd) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/location",
      "E-NOM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

// Mention types differ within the sequence (NOM, NOM, NAM) while the
// collection matches; with relaxed mention-type matching enabled the span is
// expected to be produced.
TEST_P(
    ConvertTagsToAnnotatedSpansTest,
    ConvertTagsToAnnotatedSpansHandlesInconsistentLabelTypeWhenEntityMatches) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NOM-/saft/location", "I-NOM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/true, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/true, &annotations));
  }

  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresFilteredLabel) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NAM-/saft/location", "I-NAM-/saft/location",
      "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansWithEmptyLabelFilterIgnoresAll) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {
      "O", "O", "O", "B-NOM-/saft/location", "I-NOM-/saft/location",
      "E-NOM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  EXPECT_THAT(annotations, IsEmpty());
}

TEST(PodNerUtilsTest, MergeLabelsIntoLeftSequence) { std::vector original_labels_left; original_labels_left.emplace_back( CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); original_labels_left.emplace_back( CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); original_labels_left.emplace_back( CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); original_labels_left.emplace_back( CreateLabel(BoiseType_SINGLE, MentionType_NAM, 1)); original_labels_left.emplace_back( CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); original_labels_left.emplace_back( CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); original_labels_left.emplace_back( CreateLabel(BoiseType_SINGLE, MentionType_NAM, 2)); std::vector labels_right; labels_right.emplace_back( CreateLabel(BoiseType_BEGIN, MentionType_UNDEFINED, 3)); labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0)); labels_right.emplace_back(CreateLabel(BoiseType_BEGIN, MentionType_NAM, 4)); labels_right.emplace_back( CreateLabel(BoiseType_INTERMEDIATE, MentionType_UNDEFINED, 4)); labels_right.emplace_back( CreateLabel(BoiseType_END, MentionType_UNDEFINED, 4)); std::vector labels_left = original_labels_left; ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right, /*index_first_right_tag_in_left=*/3, &labels_left)); EXPECT_EQ(labels_left.size(), 9); EXPECT_EQ(labels_left[0].collection_id, 0); EXPECT_EQ(labels_left[1].collection_id, 0); EXPECT_EQ(labels_left[2].collection_id, 0); EXPECT_EQ(labels_left[3].collection_id, 1); EXPECT_EQ(labels_left[4].collection_id, 0); EXPECT_EQ(labels_left[5].collection_id, 0); EXPECT_EQ(labels_left[6].collection_id, 4); EXPECT_EQ(labels_left[7].collection_id, 4); EXPECT_EQ(labels_left[8].collection_id, 4); labels_left = original_labels_left; ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right, /*index_first_right_tag_in_left=*/2, &labels_left)); EXPECT_EQ(labels_left.size(), 8); 
EXPECT_EQ(labels_left[0].collection_id, 0); EXPECT_EQ(labels_left[1].collection_id, 0); EXPECT_EQ(labels_left[2].collection_id, 0); EXPECT_EQ(labels_left[3].collection_id, 1); EXPECT_EQ(labels_left[4].collection_id, 0); EXPECT_EQ(labels_left[5].collection_id, 4); EXPECT_EQ(labels_left[6].collection_id, 4); EXPECT_EQ(labels_left[7].collection_id, 4); } TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanAllWordpices) { std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}}; std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {2, 3}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/15); EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 12)); } TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanInMiddle) { std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}}; std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {6, 7}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/5); EXPECT_EQ(wordpieceSpan, WordpieceSpan(3, 8)); wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {6, 7}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/6); EXPECT_EQ(wordpieceSpan, WordpieceSpan(3, 9)); wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {12, 14}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/3); EXPECT_EQ(wordpieceSpan, WordpieceSpan(9, 12)); } TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToStart) { std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}}; std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {2, 
3}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/7); EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 7)); } TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToEnd) { std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}}; std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {15, 19}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/7); EXPECT_EQ(wordpieceSpan, WordpieceSpan(5, 12)); } TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanBigSpan) { std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}}; std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan( {0, 19}, tokens, word_starts, /*num_wordpieces=*/12, /*max_num_wordpieces_in_window=*/5); EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 12)); } TEST(PodNerUtilsTest, FindFullTokensSpanInWindow) { std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; int first_token_index, num_tokens; WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow( word_starts, /*wordpiece_span=*/{0, 6}, /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index, &num_tokens); EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6)); EXPECT_EQ(first_token_index, 0); EXPECT_EQ(num_tokens, 4); updated_wordpiece_span = internal::FindFullTokensSpanInWindow( word_starts, /*wordpiece_span=*/{2, 6}, /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index, &num_tokens); EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(2, 6)); EXPECT_EQ(first_token_index, 1); EXPECT_EQ(num_tokens, 3); } TEST(PodNerUtilsTest, FindFullTokensSpanInWindowStartInMiddleOfToken) { std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; int first_token_index, num_tokens; WordpieceSpan 
updated_wordpiece_span = internal::FindFullTokensSpanInWindow( word_starts, /*wordpiece_span=*/{1, 6}, /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index, &num_tokens); EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6)); EXPECT_EQ(first_token_index, 0); EXPECT_EQ(num_tokens, 4); } TEST(PodNerUtilsTest, FindFullTokensSpanInWindowEndsInMiddleOfToken) { std::vector word_starts{0, 2, 3, 5, 6, 7, 10, 11}; int first_token_index, num_tokens; WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow( word_starts, /*wordpiece_span=*/{1, 9}, /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index, &num_tokens); EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6)); EXPECT_EQ(first_token_index, 0); EXPECT_EQ(num_tokens, 4); } TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeOne) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_first_full_token = internal::FindFirstFullTokenIndex( word_starts, /*first_wordpiece_index=*/2); EXPECT_EQ(index_first_full_token, 1); } TEST(PodNerUtilsTest, FindFirstFullTokenIndexFirst) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_first_full_token = internal::FindFirstFullTokenIndex( word_starts, /*first_wordpiece_index=*/0); EXPECT_EQ(index_first_full_token, 0); } TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeGreaterThanOne) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_first_full_token = internal::FindFirstFullTokenIndex( word_starts, /*first_wordpiece_index=*/4); EXPECT_EQ(index_first_full_token, 2); } TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeOne) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/3); EXPECT_EQ(index_last_full_token, 1); } TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeGreaterThanOne) { std::vector word_starts{1, 3, 4, 6, 8, 9}; int index_last_full_token = internal::FindLastFullTokenIndex( word_starts, 
/*num_wordpieces=*/10, /*wordpiece_end=*/6); EXPECT_EQ(index_last_full_token, 2); index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/7); EXPECT_EQ(index_last_full_token, 2); index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/5); EXPECT_EQ(index_last_full_token, 1); } TEST(PodNerUtilsTest, FindLastFullTokenIndexLast) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/12); EXPECT_EQ(index_last_full_token, 7); index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/14, /*wordpiece_end=*/14); EXPECT_EQ(index_last_full_token, 7); } TEST(PodNerUtilsTest, FindLastFullTokenIndexBeforeLast) { std::vector word_starts{1, 2, 3, 5, 6, 7, 10, 11}; int index_last_full_token = internal::FindLastFullTokenIndex( word_starts, /*num_wordpieces=*/15, /*wordpiece_end=*/12); EXPECT_EQ(index_last_full_token, 6); } TEST(PodNerUtilsTest, ExpandWindowAndAlignSequenceSmallerThanMax) { WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/8, /*wordpiece_span_to_expand=*/{2, 5}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(0, 8)); } TEST(PodNerUtilsTest, ExpandWindowAndAlignWindowLengthGreaterThanMax) { WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/100, /*wordpiece_span_to_expand=*/{2, 51}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(2, 51)); } TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToStart) { WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20, /*wordpiece_span_to_expand=*/{2, 4}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(0, 10)); } TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToEnd) { 
WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20, /*wordpiece_span_to_expand=*/{18, 20}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(10, 20)); } TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexInTheMiddle) { int window_first_wordpiece_index = 10; int window_last_wordpiece_index = 11; WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20, /*wordpiece_span_to_expand=*/{10, 12}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(6, 16)); window_first_wordpiece_index = 10; window_last_wordpiece_index = 12; maxWordpieceSpan = internal::ExpandWindowAndAlign( /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20, /*wordpiece_span_to_expand=*/{10, 13}); EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(7, 17)); } TEST(PodNerUtilsTest, WindowGenerator) { std::vector wordpiece_indices = {10, 20, 30, 40, 50, 60, 70, 80}; std::vector tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5}, {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}}; std::vector token_starts{0, 2, 3, 5, 6, 7}; WindowGenerator window_generator(wordpiece_indices, token_starts, tokens, /*max_num_wordpieces=*/4, /*sliding_window_overlap=*/1, /*span_of_interest=*/{0, 12}); VectorSpan cur_wordpiece_indices; VectorSpan cur_token_starts; VectorSpan cur_tokens; ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts, &cur_tokens)); ASSERT_FALSE(window_generator.Done()); ASSERT_EQ(cur_wordpiece_indices.size(), 3); for (int i = 0; i < 3; i++) { ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i]); } ASSERT_EQ(cur_token_starts.size(), 2); ASSERT_EQ(cur_tokens.size(), 2); for (int i = 0; i < cur_tokens.size(); i++) { ASSERT_EQ(cur_token_starts[i], token_starts[i]); ASSERT_EQ(cur_tokens[i], tokens[i]); } ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts, &cur_tokens)); ASSERT_FALSE(window_generator.Done()); ASSERT_EQ(cur_wordpiece_indices.size(), 4); 
for (int i = 0; i < cur_wordpiece_indices.size(); i++) { ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 2]); } ASSERT_EQ(cur_token_starts.size(), 3); ASSERT_EQ(cur_tokens.size(), 3); for (int i = 0; i < cur_tokens.size(); i++) { ASSERT_EQ(cur_token_starts[i], token_starts[i + 1]); ASSERT_EQ(cur_tokens[i], tokens[i + 1]); } ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts, &cur_tokens)); ASSERT_TRUE(window_generator.Done()); ASSERT_EQ(cur_wordpiece_indices.size(), 3); for (int i = 0; i < cur_wordpiece_indices.size(); i++) { ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 5]); } ASSERT_EQ(cur_token_starts.size(), 3); ASSERT_EQ(cur_tokens.size(), 3); for (int i = 0; i < cur_tokens.size(); i++) { ASSERT_EQ(cur_token_starts[i], token_starts[i + 3]); ASSERT_EQ(cur_tokens[i], tokens[i + 3]); } ASSERT_FALSE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts, &cur_tokens)); } } // namespace } // namespace libtextclassifier3