1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/pod_ner/utils.h"
18 
19 #include <algorithm>
20 #include <iostream>
21 #include <unordered_map>
22 
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_split.h"
28 
29 namespace libtextclassifier3 {
30 namespace {
31 
// Reports whether `needle` compares equal to at least one element of
// `haystack`. An empty haystack never matches.
bool StrIsOneOf(const std::string &needle,
                const std::vector<std::string> &haystack) {
  for (const std::string &candidate : haystack) {
    if (candidate == needle) {
      return true;
    }
  }
  return false;
}
37 
38 // Finds the wordpiece span of the tokens in the given span.
CodepointSpanToWordpieceSpan(const CodepointSpan & span,const std::vector<Token> & tokens,const std::vector<int32_t> & word_starts,int num_wordpieces)39 WordpieceSpan CodepointSpanToWordpieceSpan(
40     const CodepointSpan &span, const std::vector<Token> &tokens,
41     const std::vector<int32_t> &word_starts, int num_wordpieces) {
42   int span_first_wordpiece_index = 0;
43   int span_last_wordpiece_index = num_wordpieces;
44   for (int i = 0; i < tokens.size(); i++) {
45     if (tokens[i].start <= span.first && span.first < tokens[i].end) {
46       span_first_wordpiece_index = word_starts[i];
47     }
48     if (tokens[i].start <= span.second && span.second <= tokens[i].end) {
49       span_last_wordpiece_index =
50           (i + 1) < word_starts.size() ? word_starts[i + 1] : num_wordpieces;
51       break;
52     }
53   }
54   return WordpieceSpan(span_first_wordpiece_index, span_last_wordpiece_index);
55 }
56 
57 }  // namespace
58 
SaftLabelToCollection(absl::string_view saft_label)59 std::string SaftLabelToCollection(absl::string_view saft_label) {
60   return std::string(saft_label.substr(saft_label.rfind('/') + 1));
61 }
62 
63 namespace internal {
64 
// Picks the index of the last token that ends at or before `wordpiece_end`.
//
// Token i is treated as ending where token i + 1 starts (`word_starts[i + 1]`),
// and the final token as ending at `num_wordpieces`. Returns 0 when
// `word_starts` is empty or when no token boundary lies at or before
// `wordpiece_end`.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end) {
  if (word_starts.empty()) {
    return 0;
  }

  const int last_token = static_cast<int>(word_starts.size()) - 1;

  // The window reaches the end of the wordpieces and the final token starts
  // inside it, so the final token is complete.
  const bool window_covers_tail = wordpiece_end >= num_wordpieces;
  if (window_covers_tail && word_starts.back() < wordpiece_end) {
    return last_token;
  }

  // Walk backwards to the right-most token (other than token 0) that starts
  // at or before the window end; the token preceding it is complete.
  int boundary = last_token;
  while (boundary > 0 && word_starts[boundary] > wordpiece_end) {
    --boundary;
  }
  return boundary > 0 ? boundary - 1 : 0;
}
82 
// Locates the token to which `first_wordpiece_index` belongs.
//
// If some entry of `word_starts` equals `first_wordpiece_index`, the index of
// the first such entry (scanning forward) is returned. If the scan first meets
// a larger entry, the index just before it is returned (clamped to 0). If all
// entries are smaller, the last index is returned (0 for an empty vector).
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index) {
  const int num_tokens = static_cast<int>(word_starts.size());

  // Skip every token that starts before the requested wordpiece.
  int token = 0;
  while (token < num_tokens && word_starts[token] < first_wordpiece_index) {
    ++token;
  }

  if (token == num_tokens) {
    return std::max(0, num_tokens - 1);
  }
  if (word_starts[token] == first_wordpiece_index) {
    return token;
  }
  // The wordpiece lies inside the previous token.
  return std::max(0, token - 1);
}
95 
ExpandWindowAndAlign(int max_num_wordpieces_in_window,int num_wordpieces,WordpieceSpan wordpiece_span_to_expand)96 WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
97                                    int num_wordpieces,
98                                    WordpieceSpan wordpiece_span_to_expand) {
99   if (wordpiece_span_to_expand.length() >= max_num_wordpieces_in_window) {
100     return wordpiece_span_to_expand;
101   }
102   int window_first_wordpiece_index = std::max(
103       0, wordpiece_span_to_expand.begin - ((max_num_wordpieces_in_window -
104                                             wordpiece_span_to_expand.length()) /
105                                            2));
106   if ((window_first_wordpiece_index + max_num_wordpieces_in_window) >
107       num_wordpieces) {
108     window_first_wordpiece_index =
109         std::max(num_wordpieces - max_num_wordpieces_in_window, 0);
110   }
111   return WordpieceSpan(
112       window_first_wordpiece_index,
113       std::min(window_first_wordpiece_index + max_num_wordpieces_in_window,
114                num_wordpieces));
115 }
116 
FindWordpiecesWindowAroundSpan(const CodepointSpan & span_of_interest,const std::vector<Token> & tokens,const std::vector<int32_t> & word_starts,int num_wordpieces,int max_num_wordpieces_in_window)117 WordpieceSpan FindWordpiecesWindowAroundSpan(
118     const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
119     const std::vector<int32_t> &word_starts, int num_wordpieces,
120     int max_num_wordpieces_in_window) {
121   WordpieceSpan wordpiece_span_to_expand = CodepointSpanToWordpieceSpan(
122       span_of_interest, tokens, word_starts, num_wordpieces);
123   WordpieceSpan max_wordpiece_span = ExpandWindowAndAlign(
124       max_num_wordpieces_in_window, num_wordpieces, wordpiece_span_to_expand);
125   return max_wordpiece_span;
126 }
127 
// Aligns `wordpiece_span` to token boundaries.
//
// Outputs the index of the first token in `*first_token_index` and the number
// of tokens covered in `*num_tokens`, and returns the aligned wordpiece span.
// `word_starts` is indexed unconditionally, so it must not be empty.
WordpieceSpan FindFullTokensSpanInWindow(
    const std::vector<int32_t> &word_starts,
    const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
    int num_wordpieces, int *first_token_index, int *num_tokens) {
  // Snap the window start to the first wordpiece of the selected token.
  const int first_token =
      internal::FindFirstFullTokenIndex(word_starts, wordpiece_span.begin);
  *first_token_index = first_token;
  const int aligned_begin = word_starts[first_token];

  // The start may have moved backward, so re-apply the size limit to the end
  // before looking for the last complete token.
  const int capped_end =
      std::min(wordpiece_span.end, aligned_begin + max_num_wordpieces);
  const int last_token = internal::FindLastFullTokenIndex(
      word_starts, num_wordpieces, capped_end);

  // The final token ends at `num_wordpieces`; any other token ends where its
  // successor starts.
  const bool ends_on_final_token =
      static_cast<size_t>(last_token) + 1 == word_starts.size();
  const int aligned_end =
      ends_on_final_token ? num_wordpieces : word_starts[last_token + 1];

  *num_tokens = last_token - first_token + 1;
  return WordpieceSpan(aligned_begin, aligned_end);
}
150 
151 }  // namespace internal
152 
WindowGenerator(const std::vector<int32_t> & wordpiece_indices,const std::vector<int32_t> & token_starts,const std::vector<Token> & tokens,int max_num_wordpieces,int sliding_window_overlap,const CodepointSpan & span_of_interest)153 WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
154                                  const std::vector<int32_t> &token_starts,
155                                  const std::vector<Token> &tokens,
156                                  int max_num_wordpieces,
157                                  int sliding_window_overlap,
158                                  const CodepointSpan &span_of_interest)
159     : wordpiece_indices_(&wordpiece_indices),
160       token_starts_(&token_starts),
161       tokens_(&tokens),
162       max_num_effective_wordpieces_(max_num_wordpieces),
163       sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
164   entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
165       span_of_interest, tokens, token_starts, wordpiece_indices.size(),
166       max_num_wordpieces);
167   next_wordpiece_span_ = WordpieceSpan(
168       entire_wordpiece_span_.begin,
169       std::min(entire_wordpiece_span_.begin + max_num_effective_wordpieces_,
170                entire_wordpiece_span_.end));
171   previous_wordpiece_span_ = WordpieceSpan(-1, -1);
172 }
173 
Next(VectorSpan<int32_t> * cur_wordpiece_indices,VectorSpan<int32_t> * cur_token_starts,VectorSpan<Token> * cur_tokens)174 bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
175                            VectorSpan<int32_t> *cur_token_starts,
176                            VectorSpan<Token> *cur_tokens) {
177   if (Done()) {
178     return false;
179   }
180   // Update the span to cover full tokens.
181   int cur_first_token_index, cur_num_tokens;
182   next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
183       *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
184       wordpiece_indices_->size(), &cur_first_token_index, &cur_num_tokens);
185   *cur_token_starts = VectorSpan<int32_t>(
186       token_starts_->begin() + cur_first_token_index,
187       token_starts_->begin() + cur_first_token_index + cur_num_tokens);
188   *cur_tokens = VectorSpan<Token>(
189       tokens_->begin() + cur_first_token_index,
190       tokens_->begin() + cur_first_token_index + cur_num_tokens);
191 
192   // Handle the edge case where the tokens are composed of many wordpieces and
193   // the window doesn't advance.
194   if (next_wordpiece_span_.begin <= previous_wordpiece_span_.begin ||
195       next_wordpiece_span_.end <= previous_wordpiece_span_.end) {
196     return false;
197   }
198   previous_wordpiece_span_ = next_wordpiece_span_;
199 
200   int next_wordpiece_first = std::max(
201       previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_,
202       previous_wordpiece_span_.begin + 1);
203   next_wordpiece_span_ = WordpieceSpan(
204       next_wordpiece_first,
205       std::min(next_wordpiece_first + max_num_effective_wordpieces_,
206                entire_wordpiece_span_.end));
207 
208   *cur_wordpiece_indices = VectorSpan<int>(
209       wordpiece_indices_->begin() + previous_wordpiece_span_.begin,
210       wordpiece_indices_->begin() + previous_wordpiece_span_.begin +
211           previous_wordpiece_span_.length());
212 
213   return true;
214 }
215 
ConvertTagsToAnnotatedSpans(const VectorSpan<Token> & tokens,const std::vector<std::string> & tags,const std::vector<std::string> & label_filter,bool relaxed_inside_label_matching,bool relaxed_label_category_matching,float priority_score,std::vector<AnnotatedSpan> * results)216 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
217                                  const std::vector<std::string> &tags,
218                                  const std::vector<std::string> &label_filter,
219                                  bool relaxed_inside_label_matching,
220                                  bool relaxed_label_category_matching,
221                                  float priority_score,
222                                  std::vector<AnnotatedSpan> *results) {
223   AnnotatedSpan current_span;
224   std::string current_tag_type;
225   if (tags.size() > tokens.size()) {
226     return false;
227   }
228   for (int i = 0; i < tags.size(); i++) {
229     if (tags[i].empty()) {
230       return false;
231     }
232 
233     std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
234     TC3_CHECK_GT(tag_parts.size(), 0);
235     if (tag_parts[0].size() != 1) {
236       return false;
237     }
238 
239     std::string tag_type = "";
240     if (tag_parts.size() > 2) {
241       // Skip if the current label doesn't match the filter.
242       if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
243         current_tag_type = "";
244         current_span = {};
245         continue;
246       }
247 
248       // Relax the matching of the label category if specified.
249       tag_type = relaxed_label_category_matching
250                      ? std::string(tag_parts[2])
251                      : absl::StrCat(tag_parts[1], "-", tag_parts[2]);
252     }
253 
254     switch (tag_parts[0][0]) {
255       case 'S': {
256         if (tag_parts.size() != 3) {
257           return false;
258         }
259 
260         current_span = {};
261         current_tag_type = "";
262         results->push_back(AnnotatedSpan{
263             {tokens[i].start, tokens[i].end},
264             {{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
265               /*arg_score=*/1.0, priority_score}}});
266         break;
267       };
268 
269       case 'B': {
270         if (tag_parts.size() != 3) {
271           return false;
272         }
273         current_tag_type = tag_type;
274         current_span = {};
275         current_span.classification.push_back(
276             {/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
277              /*arg_score=*/1.0, priority_score});
278         current_span.span.first = tokens[i].start;
279         break;
280       };
281 
282       case 'I': {
283         if (tag_parts.size() != 3) {
284           return false;
285         }
286         if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
287           current_tag_type = "";
288           current_span = {};
289         }
290         break;
291       }
292 
293       case 'E': {
294         if (tag_parts.size() != 3) {
295           return false;
296         }
297         if (!current_tag_type.empty() && current_tag_type == tag_type) {
298           current_span.span.second = tokens[i].end;
299           results->push_back(current_span);
300           current_span = {};
301           current_tag_type = "";
302         }
303         break;
304       };
305 
306       case 'O': {
307         current_tag_type = "";
308         current_span = {};
309         break;
310       };
311 
312       default: {
313         TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
314         return false;
315       }
316     }
317   }
318   return true;
319 }
320 
321 using PodNerModel_::CollectionT;
322 using PodNerModel_::LabelT;
323 using PodNerModel_::Label_::BoiseType;
324 using PodNerModel_::Label_::MentionType;
325 
ConvertTagsToAnnotatedSpans(const VectorSpan<Token> & tokens,const std::vector<LabelT> & labels,const std::vector<CollectionT> & collections,const std::vector<MentionType> & mention_filter,bool relaxed_inside_label_matching,bool relaxed_mention_type_matching,std::vector<AnnotatedSpan> * results)326 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
327                                  const std::vector<LabelT> &labels,
328                                  const std::vector<CollectionT> &collections,
329                                  const std::vector<MentionType> &mention_filter,
330                                  bool relaxed_inside_label_matching,
331                                  bool relaxed_mention_type_matching,
332                                  std::vector<AnnotatedSpan> *results) {
333   if (labels.size() > tokens.size()) {
334     return false;
335   }
336 
337   AnnotatedSpan current_span;
338   std::string current_collection_name = "";
339 
340   for (int i = 0; i < labels.size(); i++) {
341     const LabelT &label = labels[i];
342 
343     if (label.collection_id < 0 || label.collection_id >= collections.size()) {
344       return false;
345     }
346 
347     if (std::find(mention_filter.begin(), mention_filter.end(),
348                   label.mention_type) == mention_filter.end()) {
349       // Skip if the current label doesn't match the filter.
350       current_span = {};
351       current_collection_name = "";
352       continue;
353     }
354 
355     switch (label.boise_type) {
356       case BoiseType::BoiseType_SINGLE: {
357         current_span = {};
358         current_collection_name = "";
359         results->push_back(AnnotatedSpan{
360             {tokens[i].start, tokens[i].end},
361             {{/*arg_collection=*/collections[label.collection_id].name,
362               /*arg_score=*/1.0,
363               collections[label.collection_id].single_token_priority_score}}});
364         break;
365       };
366 
367       case BoiseType::BoiseType_BEGIN: {
368         current_span = {};
369         current_span.classification.push_back(
370             {/*arg_collection=*/collections[label.collection_id].name,
371              /*arg_score=*/1.0,
372              collections[label.collection_id].multi_token_priority_score});
373         current_span.span.first = tokens[i].start;
374         current_collection_name = collections[label.collection_id].name;
375         break;
376       };
377 
378       case BoiseType::BoiseType_INTERMEDIATE: {
379         if (current_collection_name.empty() ||
380             (!relaxed_mention_type_matching &&
381              labels[i - 1].mention_type != label.mention_type) ||
382             (!relaxed_inside_label_matching &&
383              labels[i - 1].collection_id != label.collection_id)) {
384           current_span = {};
385           current_collection_name = "";
386         }
387         break;
388       }
389 
390       case BoiseType::BoiseType_END: {
391         if (!current_collection_name.empty() &&
392             current_collection_name == collections[label.collection_id].name &&
393             (relaxed_mention_type_matching ||
394              labels[i - 1].mention_type == label.mention_type)) {
395           current_span.span.second = tokens[i].end;
396           results->push_back(current_span);
397         }
398         current_span = {};
399         current_collection_name = "";
400         break;
401       };
402 
403       case BoiseType::BoiseType_O: {
404         current_span = {};
405         current_collection_name = "";
406         break;
407       };
408 
409       default: {
410         TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
411         return false;
412       }
413     }
414   }
415   return true;
416 }
417 
MergeLabelsIntoLeftSequence(const std::vector<PodNerModel_::LabelT> & labels_right,int index_first_right_tag_in_left,std::vector<PodNerModel_::LabelT> * labels_left)418 bool MergeLabelsIntoLeftSequence(
419     const std::vector<PodNerModel_::LabelT> &labels_right,
420     int index_first_right_tag_in_left,
421     std::vector<PodNerModel_::LabelT> *labels_left) {
422   if (index_first_right_tag_in_left > labels_left->size()) {
423     return false;
424   }
425 
426   int overlaping_from_left =
427       (labels_left->size() - index_first_right_tag_in_left) / 2;
428 
429   labels_left->resize(index_first_right_tag_in_left + labels_right.size());
430   std::copy(labels_right.begin() + overlaping_from_left, labels_right.end(),
431             labels_left->begin() + index_first_right_tag_in_left +
432                 overlaping_from_left);
433   return true;
434 }
435 
436 }  // namespace libtextclassifier3
437