1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_ 18 #define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_ 19 20 #include <string> 21 #include <vector> 22 23 #include "annotator/model_generated.h" 24 #include "annotator/types.h" 25 #include "absl/strings/string_view.h" 26 27 namespace libtextclassifier3 { 28 // Converts saft labels like /saft/person to collection name 'person'. 29 std::string SaftLabelToCollection(absl::string_view saft_label); 30 31 struct WordpieceSpan { 32 // Beginning index is inclusive, end index is exclusive. WordpieceSpanWordpieceSpan33 WordpieceSpan() : begin(0), end(0) {} WordpieceSpanWordpieceSpan34 WordpieceSpan(int begin, int end) : begin(begin), end(end) {} 35 int begin; 36 int end; 37 bool operator==(const WordpieceSpan &other) const { 38 return this->begin == other.begin && this->end == other.end; 39 } lengthWordpieceSpan40 int length() { return end - begin; } 41 }; 42 43 namespace internal { 44 // Finds the wordpiece window arond the given span_of_interest. If the number 45 // of wordpieces in this window is smaller than max_num_wordpieces_in_window 46 // it is expanded around the span of interest. 47 WordpieceSpan FindWordpiecesWindowAroundSpan( 48 const CodepointSpan &span_of_interest, const std::vector<Token> &tokens, 49 const std::vector<int32_t> &word_starts, int num_wordpieces, 50 int max_num_wordpieces_in_window); 51 // Expands the given wordpiece window around the given window to the be 52 // maximal possible while making sure it includes only full tokens. 53 WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window, 54 int num_wordpieces, 55 WordpieceSpan wordpiece_span_to_expand); 56 // Returns the index of the last token which ends before wordpiece_end. 57 int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts, 58 int num_wordpieces, int wordpiece_end); 59 // Returns the index of the token which includes first_wordpiece_index. 60 int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts, 61 int first_wordpiece_index); 62 // Given wordpiece_span, and max_num_wordpieces, finds: 63 // 1. The first token which includes wordpiece_span.begin. 64 // 2. The length of tokens sequence which starts from this token and: 65 // a. Its last token's last wordpiece index ends before wordpiece_span.end. 66 // b. Its overall number of wordpieces is at most max_num_wordpieces. 67 // Returns the updated wordpiece_span: begin and end wordpieces of this token 68 // sequence. 69 WordpieceSpan FindFullTokensSpanInWindow( 70 const std::vector<int32_t> &word_starts, 71 const WordpieceSpan &wordpiece_span, int max_num_wordpieces, 72 int num_wordpieces, int *first_token_index, int *num_tokens); 73 74 } // namespace internal 75 // Converts sequence of IOB tags to AnnotatedSpans. Ignores illegal sequences. 76 // Setting label_filter can also help ignore certain label tags like "NAM" or 77 // "NOM". 78 // The inside tag can be ignored when setting relaxed_inside_label_matching, 79 // e.g. B-NAM-location, I-NAM-other, E-NAM-location would be considered a valid 80 // sequence. 81 // The label category matching can be ignored when setting 82 // relaxed_label_category_matching. The matching will only operate at the entity 83 // level, e.g. B-NAM-location, E-NOM-location would be considered a valid 84 // sequence. 85 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens, 86 const std::vector<std::string> &tags, 87 const std::vector<std::string> &label_filter, 88 bool relaxed_inside_label_matching, 89 bool relaxed_label_category_matching, 90 float priority_score, 91 std::vector<AnnotatedSpan> *results); 92 93 // Like the previous function but instead of getting the tags as strings 94 // the input is PodNerModel_::LabelT along with the collections vector which 95 // hold the collection name and priorities. e.g. a tag was "B-NAM-location" and 96 // the priority_score was 1.0 it would be Label(BoiseType_BEGIN, 97 // MentionType_NAM, 1) and collections={{"xxx", 1., 1.}, 98 // {"location", 1., 1.}, {"yyy", 1., 1.}, ...}. 99 bool ConvertTagsToAnnotatedSpans( 100 const VectorSpan<Token> &tokens, 101 const std::vector<PodNerModel_::LabelT> &labels, 102 const std::vector<PodNerModel_::CollectionT> &collections, 103 const std::vector<PodNerModel_::Label_::MentionType> &mention_filter, 104 bool relaxed_inside_label_matching, bool relaxed_mention_type_matching, 105 std::vector<AnnotatedSpan> *results); 106 107 // Merge two overlaping sequences of labels, the result is placed into the left 108 // sequence. In the overlapping part takes the labels from the left sequence on 109 // the first half and from the right on the second half. 110 bool MergeLabelsIntoLeftSequence( 111 const std::vector<PodNerModel_::LabelT> &labels_right, 112 int index_first_right_tag_in_left, 113 std::vector<PodNerModel_::LabelT> *labels_left); 114 115 // This class is used to slide over {wordpiece_indices, token_starts, tokens} in 116 // windows of at most max_num_wordpieces while assuring that each window 117 // contains only full tokens. 118 class WindowGenerator { 119 public: 120 WindowGenerator(const std::vector<int32_t> &wordpiece_indices, 121 const std::vector<int32_t> &token_starts, 122 const std::vector<Token> &tokens, int max_num_wordpieces, 123 int sliding_window_overlap, 124 const CodepointSpan &span_of_interest); 125 126 bool Next(VectorSpan<int32_t> *cur_wordpiece_indices, 127 VectorSpan<int32_t> *cur_token_starts, 128 VectorSpan<Token> *cur_tokens); 129 Done()130 bool Done() const { 131 return previous_wordpiece_span_.end >= entire_wordpiece_span_.end; 132 } 133 134 private: 135 const std::vector<int32_t> *wordpiece_indices_; 136 const std::vector<int32_t> *token_starts_; 137 const std::vector<Token> *tokens_; 138 int max_num_effective_wordpieces_; 139 int sliding_window_num_wordpieces_overlap_; 140 WordpieceSpan entire_wordpiece_span_; 141 WordpieceSpan next_wordpiece_span_; 142 WordpieceSpan previous_wordpiece_span_; 143 }; 144 145 } // namespace libtextclassifier3 146 147 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_ 148