1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
19 
20 #include <string>
21 #include <vector>
22 
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "absl/strings/string_view.h"
26 
27 namespace libtextclassifier3 {
28 // Converts saft labels like /saft/person to collection name 'person'.
29 std::string SaftLabelToCollection(absl::string_view saft_label);
30 
// A half-open range [begin, end) of wordpiece indices: begin is inclusive and
// end is exclusive.
struct WordpieceSpan {
  // Creates the empty span [0, 0).
  WordpieceSpan() : begin(0), end(0) {}
  // Creates the span [begin, end).
  WordpieceSpan(int begin, int end) : begin(begin), end(end) {}

  // Index of the first wordpiece in the span (inclusive).
  int begin;
  // Index one past the last wordpiece in the span (exclusive).
  int end;

  // Two spans are equal when both of their endpoints are equal.
  bool operator==(const WordpieceSpan &other) const {
    return begin == other.begin && end == other.end;
  }

  // Returns end - begin, i.e. the number of wordpieces covered. Declared const
  // so that it can be called on const spans and through const references.
  int length() const { return end - begin; }
};
42 
43 namespace internal {
44 // Finds the wordpiece window arond the given span_of_interest. If the number
45 // of wordpieces in this window is smaller than max_num_wordpieces_in_window
46 // it is expanded around the span of interest.
47 WordpieceSpan FindWordpiecesWindowAroundSpan(
48     const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
49     const std::vector<int32_t> &word_starts, int num_wordpieces,
50     int max_num_wordpieces_in_window);
51 // Expands the given wordpiece window around the given window to the be
52 // maximal possible while making sure it includes only full tokens.
53 WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
54                                    int num_wordpieces,
55                                    WordpieceSpan wordpiece_span_to_expand);
56 // Returns the index of the last token which ends before wordpiece_end.
57 int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
58                            int num_wordpieces, int wordpiece_end);
59 // Returns the index of the token which includes first_wordpiece_index.
60 int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
61                             int first_wordpiece_index);
62 // Given wordpiece_span, and max_num_wordpieces, finds:
63 //   1. The first token which includes wordpiece_span.begin.
64 //   2. The length of tokens sequence which starts from this token and:
65 //      a. Its last token's last wordpiece index ends before wordpiece_span.end.
66 //      b. Its overall number of wordpieces is at most max_num_wordpieces.
67 // Returns the updated wordpiece_span: begin and end wordpieces of this token
68 // sequence.
69 WordpieceSpan FindFullTokensSpanInWindow(
70     const std::vector<int32_t> &word_starts,
71     const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
72     int num_wordpieces, int *first_token_index, int *num_tokens);
73 
74 }  // namespace internal
75 // Converts sequence of IOB tags to AnnotatedSpans. Ignores illegal sequences.
76 // Setting label_filter can also help ignore certain label tags like "NAM" or
77 // "NOM".
78 // The inside tag can be ignored when setting relaxed_inside_label_matching,
79 // e.g. B-NAM-location, I-NAM-other, E-NAM-location would be considered a valid
80 // sequence.
81 // The label category matching can be ignored when setting
82 // relaxed_label_category_matching. The matching will only operate at the entity
83 // level, e.g. B-NAM-location, E-NOM-location would be considered a valid
84 // sequence.
85 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
86                                  const std::vector<std::string> &tags,
87                                  const std::vector<std::string> &label_filter,
88                                  bool relaxed_inside_label_matching,
89                                  bool relaxed_label_category_matching,
90                                  float priority_score,
91                                  std::vector<AnnotatedSpan> *results);
92 
93 // Like the previous function but instead of getting the tags as strings
94 // the input is PodNerModel_::LabelT along with the collections vector which
95 // hold the collection name and priorities. e.g. a tag was "B-NAM-location" and
96 // the priority_score was 1.0 it would be Label(BoiseType_BEGIN,
97 // MentionType_NAM, 1) and collections={{"xxx", 1., 1.},
98 // {"location", 1., 1.}, {"yyy", 1., 1.}, ...}.
99 bool ConvertTagsToAnnotatedSpans(
100     const VectorSpan<Token> &tokens,
101     const std::vector<PodNerModel_::LabelT> &labels,
102     const std::vector<PodNerModel_::CollectionT> &collections,
103     const std::vector<PodNerModel_::Label_::MentionType> &mention_filter,
104     bool relaxed_inside_label_matching, bool relaxed_mention_type_matching,
105     std::vector<AnnotatedSpan> *results);
106 
107 // Merge two overlaping sequences of labels, the result is placed into the left
108 // sequence. In the overlapping part takes the labels from the left sequence on
109 // the first half and from the right on the second half.
110 bool MergeLabelsIntoLeftSequence(
111     const std::vector<PodNerModel_::LabelT> &labels_right,
112     int index_first_right_tag_in_left,
113     std::vector<PodNerModel_::LabelT> *labels_left);
114 
115 // This class is used to slide over {wordpiece_indices, token_starts, tokens} in
116 // windows of at most max_num_wordpieces while assuring that each window
117 // contains only full tokens.
118 class WindowGenerator {
119  public:
120   WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
121                   const std::vector<int32_t> &token_starts,
122                   const std::vector<Token> &tokens, int max_num_wordpieces,
123                   int sliding_window_overlap,
124                   const CodepointSpan &span_of_interest);
125 
126   bool Next(VectorSpan<int32_t> *cur_wordpiece_indices,
127             VectorSpan<int32_t> *cur_token_starts,
128             VectorSpan<Token> *cur_tokens);
129 
Done()130   bool Done() const {
131     return previous_wordpiece_span_.end >= entire_wordpiece_span_.end;
132   }
133 
134  private:
135   const std::vector<int32_t> *wordpiece_indices_;
136   const std::vector<int32_t> *token_starts_;
137   const std::vector<Token> *tokens_;
138   int max_num_effective_wordpieces_;
139   int sliding_window_num_wordpieces_overlap_;
140   WordpieceSpan entire_wordpiece_span_;
141   WordpieceSpan next_wordpiece_span_;
142   WordpieceSpan previous_wordpiece_span_;
143 };
144 
145 }  // namespace libtextclassifier3
146 
147 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
148