/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/pod_ner/utils.h"

#include <algorithm>
#include <iostream>
#include <unordered_map>

#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"

// NOTE(review): the system header names above and the template arguments
// below (std::vector<...>, VectorSpan<...>, static_cast<...>) were missing
// from the text this file was reviewed from and have been reconstructed from
// how the values are used. Confirm them against annotator/pod_ner/utils.h.

namespace libtextclassifier3 {
namespace {

// Returns true if the needle string is contained in the haystack.
bool StrIsOneOf(const std::string &needle,
                const std::vector<std::string> &haystack) {
  return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
}

// Finds the wordpiece span of the tokens in the given span.
//
// word_starts[i] is read as the index of the first wordpiece of tokens[i].
// The returned span begins at the first wordpiece of the token containing
// span.first and ends (exclusive) at the first wordpiece of the token that
// follows the one containing span.second. When no token matches an endpoint,
// the defaults 0 and num_wordpieces are used for that endpoint.
WordpieceSpan CodepointSpanToWordpieceSpan(
    const CodepointSpan &span, const std::vector<Token> &tokens,
    const std::vector<int32_t> &word_starts, int num_wordpieces) {
  int span_first_wordpiece_index = 0;
  int span_last_wordpiece_index = num_wordpieces;
  // Only visit tokens that have a word_starts entry so that word_starts is
  // never indexed past its end if the two vectors disagree in length.
  const size_t num_aligned_tokens =
      std::min(tokens.size(), word_starts.size());
  for (size_t i = 0; i < num_aligned_tokens; i++) {
    if (tokens[i].start <= span.first && span.first < tokens[i].end) {
      span_first_wordpiece_index = word_starts[i];
    }
    if (tokens[i].start <= span.second && span.second <= tokens[i].end) {
      // The span ends where the next token starts; the last token ends at
      // num_wordpieces.
      span_last_wordpiece_index =
          (i + 1) < word_starts.size() ? word_starts[i + 1] : num_wordpieces;
      break;
    }
  }
  return WordpieceSpan(span_first_wordpiece_index, span_last_wordpiece_index);
}

}  // namespace

// Returns the part of saft_label that follows its last '/'. When the label
// contains no '/', rfind() yields npos, npos + 1 wraps around to 0 and the
// whole label is returned.
std::string SaftLabelToCollection(absl::string_view saft_label) {
  return std::string(saft_label.substr(saft_label.rfind('/') + 1));
}

namespace internal {

// Returns the index of the last token that ends at or before wordpiece_end.
//
// - Returns 0 when word_starts is empty.
// - Returns the last token index when the last token starts before
//   wordpiece_end and wordpiece_end is at least num_wordpieces.
// - Otherwise scans from the back and, for the first index i > 0 with
//   word_starts[i] <= wordpiece_end, returns i - 1; returns 0 if there is no
//   such index.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end) {
  if (word_starts.empty()) {
    return 0;
  }
  if (*word_starts.rbegin() < wordpiece_end &&
      num_wordpieces <= wordpiece_end) {
    // Last token.
    return static_cast<int>(word_starts.size()) - 1;
  }
  for (int i = static_cast<int>(word_starts.size()) - 1; i > 0; --i) {
    if (word_starts[i] <= wordpiece_end) {
      return (i - 1);
    }
  }
  return 0;
}

// Returns the index of the token whose first wordpiece is
// first_wordpiece_index. If no token starts exactly there, returns the index
// just before the first token that starts after it (clamped to 0). If every
// token starts before first_wordpiece_index, returns the last token index (0
// for an empty word_starts).
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index) {
  for (int i = 0; i < static_cast<int>(word_starts.size()); ++i) {
    if (word_starts[i] == first_wordpiece_index) {
      return i;
    } else if (word_starts[i] > first_wordpiece_index) {
      return std::max(0, i - 1);
    }
  }

  return std::max(0, static_cast<int>(word_starts.size()) - 1);
}

// Grows wordpiece_span_to_expand to a window of at most
// max_num_wordpieces_in_window wordpieces inside [0, num_wordpieces).
//
// A span that is already at least as long as the maximum is returned
// unchanged. Otherwise the missing wordpieces are split evenly before and
// after the span; if the window would then run past num_wordpieces it is
// shifted left so that it ends at num_wordpieces (never starting below 0).
WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
                                   int num_wordpieces,
                                   WordpieceSpan wordpiece_span_to_expand) {
  if (wordpiece_span_to_expand.length() >= max_num_wordpieces_in_window) {
    return wordpiece_span_to_expand;
  }
  int window_first_wordpiece_index = std::max(
      0, wordpiece_span_to_expand.begin - ((max_num_wordpieces_in_window -
                                            wordpiece_span_to_expand.length()) /
                                           2));
  if ((window_first_wordpiece_index + max_num_wordpieces_in_window) >
      num_wordpieces) {
    window_first_wordpiece_index =
        std::max(num_wordpieces - max_num_wordpieces_in_window, 0);
  }
  return WordpieceSpan(
      window_first_wordpiece_index,
      std::min(window_first_wordpiece_index + max_num_wordpieces_in_window,
               num_wordpieces));
}

// Maps span_of_interest to wordpieces and expands the result to a window of
// up to max_num_wordpieces_in_window wordpieces (see ExpandWindowAndAlign).
WordpieceSpan FindWordpiecesWindowAroundSpan(
    const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
    const std::vector<int32_t> &word_starts, int num_wordpieces,
    int max_num_wordpieces_in_window) {
  WordpieceSpan wordpiece_span_to_expand = CodepointSpanToWordpieceSpan(
      span_of_interest, tokens, word_starts, num_wordpieces);
  WordpieceSpan max_wordpiece_span = ExpandWindowAndAlign(
      max_num_wordpieces_in_window, num_wordpieces, wordpiece_span_to_expand);
  return max_wordpiece_span;
}

// Aligns wordpiece_span to token boundaries.
//
// Writes the index of the first token of the aligned window to
// *first_token_index and the number of tokens it covers to *num_tokens, and
// returns the aligned wordpiece span.
//
// Precondition: word_starts must be non-empty; word_starts[*first_token_index]
// is read unconditionally.
WordpieceSpan FindFullTokensSpanInWindow(
    const std::vector<int32_t> &word_starts,
    const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
    int num_wordpieces, int *first_token_index, int *num_tokens) {
  int window_first_wordpiece_index = wordpiece_span.begin;
  *first_token_index = internal::FindFirstFullTokenIndex(
      word_starts, window_first_wordpiece_index);
  window_first_wordpiece_index = word_starts[*first_token_index];

  // Need to update the last index in case the first moved backward.
  int wordpiece_window_end = std::min(
      wordpiece_span.end, window_first_wordpiece_index + max_num_wordpieces);
  int last_token_index;
  last_token_index = internal::FindLastFullTokenIndex(
      word_starts, num_wordpieces, wordpiece_window_end);
  // The window ends where the token after the last full token starts, or at
  // num_wordpieces when the last full token is the final token.
  wordpiece_window_end =
      last_token_index == (static_cast<int>(word_starts.size()) - 1)
          ? num_wordpieces
          : word_starts[last_token_index + 1];

  *num_tokens = last_token_index - *first_token_index + 1;
  return WordpieceSpan(window_first_wordpiece_index, wordpiece_window_end);
}

}  // namespace internal

// Stores pointers to the given vectors (they are not copied, so they must
// outlive the generator), computes the overall wordpiece window around
// span_of_interest and sets up the first sliding window at its beginning.
WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
                                 const std::vector<int32_t> &token_starts,
                                 const std::vector<Token> &tokens,
                                 int max_num_wordpieces,
                                 int sliding_window_overlap,
                                 const CodepointSpan &span_of_interest)
    : wordpiece_indices_(&wordpiece_indices),
      token_starts_(&token_starts),
      tokens_(&tokens),
      max_num_effective_wordpieces_(max_num_wordpieces),
      sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
  entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
      span_of_interest, tokens, token_starts, wordpiece_indices.size(),
      max_num_wordpieces);
  next_wordpiece_span_ = WordpieceSpan(
      entire_wordpiece_span_.begin,
      std::min(entire_wordpiece_span_.begin + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end));
  // Sentinel: no window has been produced yet.
  previous_wordpiece_span_ = WordpieceSpan(-1, -1);
}

// Produces the next sliding window.
//
// Returns false when Done() is true or when the token-aligned window fails to
// advance past the previous one; returns true otherwise. On success all three
// outputs describe the new window. On the "did not advance" path
// *cur_token_starts and *cur_tokens have already been overwritten while
// *cur_wordpiece_indices is left untouched, so callers must not use the
// outputs after a false return.
bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
                           VectorSpan<int32_t> *cur_token_starts,
                           VectorSpan<Token> *cur_tokens) {
  if (Done()) {
    return false;
  }
  // Update the span to cover full tokens.
  int cur_first_token_index, cur_num_tokens;
  next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
      *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
      wordpiece_indices_->size(), &cur_first_token_index, &cur_num_tokens);
  *cur_token_starts = VectorSpan<int32_t>(
      token_starts_->begin() + cur_first_token_index,
      token_starts_->begin() + cur_first_token_index + cur_num_tokens);
  *cur_tokens = VectorSpan<Token>(
      tokens_->begin() + cur_first_token_index,
      tokens_->begin() + cur_first_token_index + cur_num_tokens);

  // Handle the edge case where the tokens are composed of many wordpieces and
  // the window doesn't advance.
  if (next_wordpiece_span_.begin <= previous_wordpiece_span_.begin ||
      next_wordpiece_span_.end <= previous_wordpiece_span_.end) {
    return false;
  }
  previous_wordpiece_span_ = next_wordpiece_span_;

  // The next window starts `overlap` wordpieces before the end of this one,
  // but always at least one wordpiece after this window's start.
  int next_wordpiece_first = std::max(
      previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_,
      previous_wordpiece_span_.begin + 1);
  next_wordpiece_span_ = WordpieceSpan(
      next_wordpiece_first,
      std::min(next_wordpiece_first + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end));

  *cur_wordpiece_indices = VectorSpan<int>(
      wordpiece_indices_->begin() + previous_wordpiece_span_.begin,
      wordpiece_indices_->begin() + previous_wordpiece_span_.begin +
          previous_wordpiece_span_.length());

  return true;
}

// Converts per-token string tags into annotated spans appended to *results.
//
// Each tag is split on '-'. The first part must be a single character, one of
// 'S', 'B', 'I', 'E' or 'O'; 'S', 'B', 'I' and 'E' tags must have exactly
// three parts. For tags with more than two parts, the second part is matched
// against label_filter (tags that do not match reset the open span and are
// skipped) and the third part is passed to SaftLabelToCollection() to obtain
// the collection name.
//
// Returns false if there are more tags than tokens, if a tag is empty or
// malformed, or if its first character is not recognized. Spans appended to
// *results before a failure is detected are not removed.
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
                                 const std::vector<std::string> &tags,
                                 const std::vector<std::string> &label_filter,
                                 bool relaxed_inside_label_matching,
                                 bool relaxed_label_category_matching,
                                 float priority_score,
                                 std::vector<AnnotatedSpan> *results) {
  AnnotatedSpan current_span;
  std::string current_tag_type;
  if (tags.size() > tokens.size()) {
    return false;
  }
  for (int i = 0; i < static_cast<int>(tags.size()); i++) {
    if (tags[i].empty()) {
      return false;
    }

    std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
    TC3_CHECK_GT(tag_parts.size(), 0);
    if (tag_parts[0].size() != 1) {
      return false;
    }

    std::string tag_type = "";
    if (tag_parts.size() > 2) {
      // Skip if the current label doesn't match the filter.
      if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
        current_tag_type = "";
        current_span = {};
        continue;
      }

      // Relax the matching of the label category if specified.
      tag_type = relaxed_label_category_matching
                     ? std::string(tag_parts[2])
                     : absl::StrCat(tag_parts[1], "-", tag_parts[2]);
    }

    switch (tag_parts[0][0]) {
      case 'S': {
        // Single-token span: emitted immediately; any open span is dropped.
        if (tag_parts.size() != 3) {
          return false;
        }

        current_span = {};
        current_tag_type = "";
        results->push_back(AnnotatedSpan{
            {tokens[i].start, tokens[i].end},
            {{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
              /*arg_score=*/1.0, priority_score}}});
        break;
      }

      case 'B': {
        // Opens a new span, replacing any span that was still open.
        if (tag_parts.size() != 3) {
          return false;
        }
        current_tag_type = tag_type;
        current_span = {};
        current_span.classification.push_back(
            {/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
             /*arg_score=*/1.0, priority_score});
        current_span.span.first = tokens[i].start;
        break;
      }

      case 'I': {
        // Continues the open span; a type mismatch drops it unless relaxed
        // inside-label matching is enabled.
        if (tag_parts.size() != 3) {
          return false;
        }
        if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
          current_tag_type = "";
          current_span = {};
        }
        break;
      }

      case 'E': {
        // Closes and emits the open span when the type matches.
        // NOTE(review): on a mismatch the open span is left untouched here,
        // whereas the LabelT overload below always resets on END - confirm
        // that this difference is intended.
        if (tag_parts.size() != 3) {
          return false;
        }
        if (!current_tag_type.empty() && current_tag_type == tag_type) {
          current_span.span.second = tokens[i].end;
          results->push_back(current_span);
          current_span = {};
          current_tag_type = "";
        }
        break;
      }

      case 'O': {
        current_tag_type = "";
        current_span = {};
        break;
      }

      default: {
        TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
        return false;
      }
    }
  }
  return true;
}

using PodNerModel_::CollectionT;
using PodNerModel_::LabelT;
using PodNerModel_::Label_::BoiseType;
using PodNerModel_::Label_::MentionType;

// Converts per-token structured labels into annotated spans appended to
// *results.
//
// Labels whose mention_type is not in mention_filter reset the open span and
// are skipped. SINGLE labels emit a one-token span scored with the
// collection's single_token_priority_score; BEGIN ... END sequences emit a
// span scored with multi_token_priority_score. INTERMEDIATE and END labels
// are compared against the immediately preceding label (labels[i - 1]); that
// access only happens while a span is open, i.e. after an earlier BEGIN.
//
// Returns false if there are more labels than tokens, if a collection_id is
// outside `collections`, or if the boise_type is not recognized. Spans
// appended to *results before a failure is detected are not removed.
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
                                 const std::vector<LabelT> &labels,
                                 const std::vector<CollectionT> &collections,
                                 const std::vector<MentionType> &mention_filter,
                                 bool relaxed_inside_label_matching,
                                 bool relaxed_mention_type_matching,
                                 std::vector<AnnotatedSpan> *results) {
  if (labels.size() > tokens.size()) {
    return false;
  }

  AnnotatedSpan current_span;
  std::string current_collection_name = "";

  for (int i = 0; i < static_cast<int>(labels.size()); i++) {
    const LabelT &label = labels[i];

    if (label.collection_id < 0 || label.collection_id >= collections.size()) {
      return false;
    }

    if (std::find(mention_filter.begin(), mention_filter.end(),
                  label.mention_type) == mention_filter.end()) {
      // Skip if the current label doesn't match the filter.
      current_span = {};
      current_collection_name = "";
      continue;
    }

    switch (label.boise_type) {
      case BoiseType::BoiseType_SINGLE: {
        current_span = {};
        current_collection_name = "";
        results->push_back(AnnotatedSpan{
            {tokens[i].start, tokens[i].end},
            {{/*arg_collection=*/collections[label.collection_id].name,
              /*arg_score=*/1.0,
              collections[label.collection_id].single_token_priority_score}}});
        break;
      }

      case BoiseType::BoiseType_BEGIN: {
        current_span = {};
        current_span.classification.push_back(
            {/*arg_collection=*/collections[label.collection_id].name,
             /*arg_score=*/1.0,
             collections[label.collection_id].multi_token_priority_score});
        current_span.span.first = tokens[i].start;
        current_collection_name = collections[label.collection_id].name;
        break;
      }

      case BoiseType::BoiseType_INTERMEDIATE: {
        // Drop the open span when there is none to continue, or when the
        // mention type / collection changed and the respective relaxation is
        // disabled. The empty() check comes first so that labels[i - 1] is
        // only read after a BEGIN has been seen.
        if (current_collection_name.empty() ||
            (!relaxed_mention_type_matching &&
             labels[i - 1].mention_type != label.mention_type) ||
            (!relaxed_inside_label_matching &&
             labels[i - 1].collection_id != label.collection_id)) {
          current_span = {};
          current_collection_name = "";
        }
        break;
      }

      case BoiseType::BoiseType_END: {
        if (!current_collection_name.empty() &&
            current_collection_name == collections[label.collection_id].name &&
            (relaxed_mention_type_matching ||
             labels[i - 1].mention_type == label.mention_type)) {
          current_span.span.second = tokens[i].end;
          results->push_back(current_span);
        }
        // END always closes the current span, whether or not it was emitted.
        current_span = {};
        current_collection_name = "";
        break;
      }

      case BoiseType::BoiseType_O: {
        current_span = {};
        current_collection_name = "";
        break;
      }

      default: {
        TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
        return false;
      }
    }
  }
  return true;
}

// Merges labels_right into *labels_left, where labels_right[0] corresponds to
// (*labels_left)[index_first_right_tag_in_left].
//
// The overlapping region is split in half: its first half keeps the labels
// from the left sequence, the rest is taken from the right sequence.
// *labels_left is resized to end where labels_right ends.
//
// Returns false, leaving *labels_left unmodified, when the index is negative
// or past the end of *labels_left, or when labels_right is shorter than half
// of the overlap (there would be nothing valid to copy from).
bool MergeLabelsIntoLeftSequence(
    const std::vector<PodNerModel_::LabelT> &labels_right,
    int index_first_right_tag_in_left,
    std::vector<PodNerModel_::LabelT> *labels_left) {
  if (index_first_right_tag_in_left < 0 ||
      static_cast<size_t>(index_first_right_tag_in_left) >
          labels_left->size()) {
    return false;
  }
  const size_t first_right_index =
      static_cast<size_t>(index_first_right_tag_in_left);

  const size_t overlaping_from_left =
      (labels_left->size() - first_right_index) / 2;
  // Without this check both iterators handed to std::copy below would point
  // past the end of their vectors.
  if (overlaping_from_left > labels_right.size()) {
    return false;
  }

  labels_left->resize(first_right_index + labels_right.size());
  std::copy(labels_right.begin() + overlaping_from_left, labels_right.end(),
            labels_left->begin() + first_right_index + overlaping_from_left);
  return true;
}

}  // namespace libtextclassifier3