1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/pod_ner/utils.h"
18
19 #include <algorithm>
20 #include <iostream>
21 #include <unordered_map>
22
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_split.h"
28
29 namespace libtextclassifier3 {
30 namespace {
31
// Returns true when some element of `haystack` compares equal to `needle`;
// an empty haystack never matches.
bool StrIsOneOf(const std::string &needle,
                const std::vector<std::string> &haystack) {
  return std::any_of(
      haystack.begin(), haystack.end(),
      [&needle](const std::string &candidate) { return candidate == needle; });
}
37
38 // Finds the wordpiece span of the tokens in the given span.
CodepointSpanToWordpieceSpan(const CodepointSpan & span,const std::vector<Token> & tokens,const std::vector<int32_t> & word_starts,int num_wordpieces)39 WordpieceSpan CodepointSpanToWordpieceSpan(
40 const CodepointSpan &span, const std::vector<Token> &tokens,
41 const std::vector<int32_t> &word_starts, int num_wordpieces) {
42 int span_first_wordpiece_index = 0;
43 int span_last_wordpiece_index = num_wordpieces;
44 for (int i = 0; i < tokens.size(); i++) {
45 if (tokens[i].start <= span.first && span.first < tokens[i].end) {
46 span_first_wordpiece_index = word_starts[i];
47 }
48 if (tokens[i].start <= span.second && span.second <= tokens[i].end) {
49 span_last_wordpiece_index =
50 (i + 1) < word_starts.size() ? word_starts[i + 1] : num_wordpieces;
51 break;
52 }
53 }
54 return WordpieceSpan(span_first_wordpiece_index, span_last_wordpiece_index);
55 }
56
57 } // namespace
58
SaftLabelToCollection(absl::string_view saft_label)59 std::string SaftLabelToCollection(absl::string_view saft_label) {
60 return std::string(saft_label.substr(saft_label.rfind('/') + 1));
61 }
62
63 namespace internal {
64
// Returns the index of the last token that lies completely before
// `wordpiece_end`, given the first-wordpiece index of every token.
//
// - Empty `word_starts` yields 0.
// - When `wordpiece_end` reaches `num_wordpieces` and lies past the start of
//   the last token, the last token itself is returned.
// - Otherwise the token preceding the last one whose start is
//   <= `wordpiece_end` is returned (token 0 is never used as that boundary),
//   falling back to 0.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end) {
  const int num_tokens = static_cast<int>(word_starts.size());
  if (num_tokens == 0) {
    return 0;
  }

  // The window runs to the end of the input, so the last token is complete.
  const bool window_covers_tail = wordpiece_end >= num_wordpieces;
  if (window_covers_tail && word_starts.back() < wordpiece_end) {
    return num_tokens - 1;
  }

  // Walk backwards to the last token (excluding token 0) that starts at or
  // before the window end; the token before it is the last complete one.
  int boundary_token = num_tokens - 1;
  while (boundary_token > 0 && word_starts[boundary_token] > wordpiece_end) {
    --boundary_token;
  }
  return boundary_token > 0 ? boundary_token - 1 : 0;
}
82
// Returns the index of the token that contains wordpiece
// `first_wordpiece_index`, given the first-wordpiece index of every token.
//
// If a token starts exactly at that wordpiece its index is returned. If the
// wordpiece falls inside a token, the index of that (earlier-starting) token
// is returned. If every token starts before the wordpiece, the last token is
// returned. The result is never negative; empty `word_starts` yields 0.
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index) {
  const int num_tokens = static_cast<int>(word_starts.size());

  // Skip every token that starts strictly before the requested wordpiece.
  int token = 0;
  while (token < num_tokens && word_starts[token] < first_wordpiece_index) {
    ++token;
  }

  if (token == num_tokens) {
    // No token starts at or after the wordpiece: use the last token.
    return std::max(0, num_tokens - 1);
  }
  if (word_starts[token] == first_wordpiece_index) {
    return token;
  }
  // The wordpiece is in the middle of the previous token.
  return std::max(0, token - 1);
}
95
ExpandWindowAndAlign(int max_num_wordpieces_in_window,int num_wordpieces,WordpieceSpan wordpiece_span_to_expand)96 WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
97 int num_wordpieces,
98 WordpieceSpan wordpiece_span_to_expand) {
99 if (wordpiece_span_to_expand.length() >= max_num_wordpieces_in_window) {
100 return wordpiece_span_to_expand;
101 }
102 int window_first_wordpiece_index = std::max(
103 0, wordpiece_span_to_expand.begin - ((max_num_wordpieces_in_window -
104 wordpiece_span_to_expand.length()) /
105 2));
106 if ((window_first_wordpiece_index + max_num_wordpieces_in_window) >
107 num_wordpieces) {
108 window_first_wordpiece_index =
109 std::max(num_wordpieces - max_num_wordpieces_in_window, 0);
110 }
111 return WordpieceSpan(
112 window_first_wordpiece_index,
113 std::min(window_first_wordpiece_index + max_num_wordpieces_in_window,
114 num_wordpieces));
115 }
116
FindWordpiecesWindowAroundSpan(const CodepointSpan & span_of_interest,const std::vector<Token> & tokens,const std::vector<int32_t> & word_starts,int num_wordpieces,int max_num_wordpieces_in_window)117 WordpieceSpan FindWordpiecesWindowAroundSpan(
118 const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
119 const std::vector<int32_t> &word_starts, int num_wordpieces,
120 int max_num_wordpieces_in_window) {
121 WordpieceSpan wordpiece_span_to_expand = CodepointSpanToWordpieceSpan(
122 span_of_interest, tokens, word_starts, num_wordpieces);
123 WordpieceSpan max_wordpiece_span = ExpandWindowAndAlign(
124 max_num_wordpieces_in_window, num_wordpieces, wordpiece_span_to_expand);
125 return max_wordpiece_span;
126 }
127
// Snaps `wordpiece_span` to token boundaries and reports which tokens it
// covers.
//
// The start is moved to the first wordpiece of the token found by
// FindFirstFullTokenIndex; the end is limited to `max_num_wordpieces` past
// that start and then moved to the boundary after the token found by
// FindLastFullTokenIndex (`num_wordpieces` when that is the last token).
//
// Outputs: `*first_token_index` receives the index of the first covered token
// and `*num_tokens` the number of covered tokens.
//
// `word_starts` is read at the computed token index without a size check, so
// callers are expected to pass a non-empty vector.
WordpieceSpan FindFullTokensSpanInWindow(
    const std::vector<int32_t> &word_starts,
    const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
    int num_wordpieces, int *first_token_index, int *num_tokens) {
  const int first_token =
      internal::FindFirstFullTokenIndex(word_starts, wordpiece_span.begin);
  *first_token_index = first_token;
  const int aligned_begin = word_starts[first_token];

  // Aligning may have moved the start backwards, so re-limit the end to keep
  // the window within the wordpiece budget.
  const int budgeted_end =
      std::min(wordpiece_span.end, aligned_begin + max_num_wordpieces);
  const int last_token = internal::FindLastFullTokenIndex(
      word_starts, num_wordpieces, budgeted_end);

  int aligned_end;
  if (last_token == (word_starts.size() - 1)) {
    aligned_end = num_wordpieces;
  } else {
    aligned_end = word_starts[last_token + 1];
  }

  *num_tokens = last_token - first_token + 1;
  return WordpieceSpan(aligned_begin, aligned_end);
}
150
151 } // namespace internal
152
WindowGenerator(const std::vector<int32_t> & wordpiece_indices,const std::vector<int32_t> & token_starts,const std::vector<Token> & tokens,int max_num_wordpieces,int sliding_window_overlap,const CodepointSpan & span_of_interest)153 WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
154 const std::vector<int32_t> &token_starts,
155 const std::vector<Token> &tokens,
156 int max_num_wordpieces,
157 int sliding_window_overlap,
158 const CodepointSpan &span_of_interest)
159 : wordpiece_indices_(&wordpiece_indices),
160 token_starts_(&token_starts),
161 tokens_(&tokens),
162 max_num_effective_wordpieces_(max_num_wordpieces),
163 sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
164 entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
165 span_of_interest, tokens, token_starts, wordpiece_indices.size(),
166 max_num_wordpieces);
167 next_wordpiece_span_ = WordpieceSpan(
168 entire_wordpiece_span_.begin,
169 std::min(entire_wordpiece_span_.begin + max_num_effective_wordpieces_,
170 entire_wordpiece_span_.end));
171 previous_wordpiece_span_ = WordpieceSpan(-1, -1);
172 }
173
Next(VectorSpan<int32_t> * cur_wordpiece_indices,VectorSpan<int32_t> * cur_token_starts,VectorSpan<Token> * cur_tokens)174 bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
175 VectorSpan<int32_t> *cur_token_starts,
176 VectorSpan<Token> *cur_tokens) {
177 if (Done()) {
178 return false;
179 }
180 // Update the span to cover full tokens.
181 int cur_first_token_index, cur_num_tokens;
182 next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
183 *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
184 wordpiece_indices_->size(), &cur_first_token_index, &cur_num_tokens);
185 *cur_token_starts = VectorSpan<int32_t>(
186 token_starts_->begin() + cur_first_token_index,
187 token_starts_->begin() + cur_first_token_index + cur_num_tokens);
188 *cur_tokens = VectorSpan<Token>(
189 tokens_->begin() + cur_first_token_index,
190 tokens_->begin() + cur_first_token_index + cur_num_tokens);
191
192 // Handle the edge case where the tokens are composed of many wordpieces and
193 // the window doesn't advance.
194 if (next_wordpiece_span_.begin <= previous_wordpiece_span_.begin ||
195 next_wordpiece_span_.end <= previous_wordpiece_span_.end) {
196 return false;
197 }
198 previous_wordpiece_span_ = next_wordpiece_span_;
199
200 int next_wordpiece_first = std::max(
201 previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_,
202 previous_wordpiece_span_.begin + 1);
203 next_wordpiece_span_ = WordpieceSpan(
204 next_wordpiece_first,
205 std::min(next_wordpiece_first + max_num_effective_wordpieces_,
206 entire_wordpiece_span_.end));
207
208 *cur_wordpiece_indices = VectorSpan<int>(
209 wordpiece_indices_->begin() + previous_wordpiece_span_.begin,
210 wordpiece_indices_->begin() + previous_wordpiece_span_.begin +
211 previous_wordpiece_span_.length());
212
213 return true;
214 }
215
ConvertTagsToAnnotatedSpans(const VectorSpan<Token> & tokens,const std::vector<std::string> & tags,const std::vector<std::string> & label_filter,bool relaxed_inside_label_matching,bool relaxed_label_category_matching,float priority_score,std::vector<AnnotatedSpan> * results)216 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
217 const std::vector<std::string> &tags,
218 const std::vector<std::string> &label_filter,
219 bool relaxed_inside_label_matching,
220 bool relaxed_label_category_matching,
221 float priority_score,
222 std::vector<AnnotatedSpan> *results) {
223 AnnotatedSpan current_span;
224 std::string current_tag_type;
225 if (tags.size() > tokens.size()) {
226 return false;
227 }
228 for (int i = 0; i < tags.size(); i++) {
229 if (tags[i].empty()) {
230 return false;
231 }
232
233 std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
234 TC3_CHECK_GT(tag_parts.size(), 0);
235 if (tag_parts[0].size() != 1) {
236 return false;
237 }
238
239 std::string tag_type = "";
240 if (tag_parts.size() > 2) {
241 // Skip if the current label doesn't match the filter.
242 if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
243 current_tag_type = "";
244 current_span = {};
245 continue;
246 }
247
248 // Relax the matching of the label category if specified.
249 tag_type = relaxed_label_category_matching
250 ? std::string(tag_parts[2])
251 : absl::StrCat(tag_parts[1], "-", tag_parts[2]);
252 }
253
254 switch (tag_parts[0][0]) {
255 case 'S': {
256 if (tag_parts.size() != 3) {
257 return false;
258 }
259
260 current_span = {};
261 current_tag_type = "";
262 results->push_back(AnnotatedSpan{
263 {tokens[i].start, tokens[i].end},
264 {{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
265 /*arg_score=*/1.0, priority_score}}});
266 break;
267 };
268
269 case 'B': {
270 if (tag_parts.size() != 3) {
271 return false;
272 }
273 current_tag_type = tag_type;
274 current_span = {};
275 current_span.classification.push_back(
276 {/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
277 /*arg_score=*/1.0, priority_score});
278 current_span.span.first = tokens[i].start;
279 break;
280 };
281
282 case 'I': {
283 if (tag_parts.size() != 3) {
284 return false;
285 }
286 if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
287 current_tag_type = "";
288 current_span = {};
289 }
290 break;
291 }
292
293 case 'E': {
294 if (tag_parts.size() != 3) {
295 return false;
296 }
297 if (!current_tag_type.empty() && current_tag_type == tag_type) {
298 current_span.span.second = tokens[i].end;
299 results->push_back(current_span);
300 current_span = {};
301 current_tag_type = "";
302 }
303 break;
304 };
305
306 case 'O': {
307 current_tag_type = "";
308 current_span = {};
309 break;
310 };
311
312 default: {
313 TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
314 return false;
315 }
316 }
317 }
318 return true;
319 }
320
321 using PodNerModel_::CollectionT;
322 using PodNerModel_::LabelT;
323 using PodNerModel_::Label_::BoiseType;
324 using PodNerModel_::Label_::MentionType;
325
ConvertTagsToAnnotatedSpans(const VectorSpan<Token> & tokens,const std::vector<LabelT> & labels,const std::vector<CollectionT> & collections,const std::vector<MentionType> & mention_filter,bool relaxed_inside_label_matching,bool relaxed_mention_type_matching,std::vector<AnnotatedSpan> * results)326 bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
327 const std::vector<LabelT> &labels,
328 const std::vector<CollectionT> &collections,
329 const std::vector<MentionType> &mention_filter,
330 bool relaxed_inside_label_matching,
331 bool relaxed_mention_type_matching,
332 std::vector<AnnotatedSpan> *results) {
333 if (labels.size() > tokens.size()) {
334 return false;
335 }
336
337 AnnotatedSpan current_span;
338 std::string current_collection_name = "";
339
340 for (int i = 0; i < labels.size(); i++) {
341 const LabelT &label = labels[i];
342
343 if (label.collection_id < 0 || label.collection_id >= collections.size()) {
344 return false;
345 }
346
347 if (std::find(mention_filter.begin(), mention_filter.end(),
348 label.mention_type) == mention_filter.end()) {
349 // Skip if the current label doesn't match the filter.
350 current_span = {};
351 current_collection_name = "";
352 continue;
353 }
354
355 switch (label.boise_type) {
356 case BoiseType::BoiseType_SINGLE: {
357 current_span = {};
358 current_collection_name = "";
359 results->push_back(AnnotatedSpan{
360 {tokens[i].start, tokens[i].end},
361 {{/*arg_collection=*/collections[label.collection_id].name,
362 /*arg_score=*/1.0,
363 collections[label.collection_id].single_token_priority_score}}});
364 break;
365 };
366
367 case BoiseType::BoiseType_BEGIN: {
368 current_span = {};
369 current_span.classification.push_back(
370 {/*arg_collection=*/collections[label.collection_id].name,
371 /*arg_score=*/1.0,
372 collections[label.collection_id].multi_token_priority_score});
373 current_span.span.first = tokens[i].start;
374 current_collection_name = collections[label.collection_id].name;
375 break;
376 };
377
378 case BoiseType::BoiseType_INTERMEDIATE: {
379 if (current_collection_name.empty() ||
380 (!relaxed_mention_type_matching &&
381 labels[i - 1].mention_type != label.mention_type) ||
382 (!relaxed_inside_label_matching &&
383 labels[i - 1].collection_id != label.collection_id)) {
384 current_span = {};
385 current_collection_name = "";
386 }
387 break;
388 }
389
390 case BoiseType::BoiseType_END: {
391 if (!current_collection_name.empty() &&
392 current_collection_name == collections[label.collection_id].name &&
393 (relaxed_mention_type_matching ||
394 labels[i - 1].mention_type == label.mention_type)) {
395 current_span.span.second = tokens[i].end;
396 results->push_back(current_span);
397 }
398 current_span = {};
399 current_collection_name = "";
400 break;
401 };
402
403 case BoiseType::BoiseType_O: {
404 current_span = {};
405 current_collection_name = "";
406 break;
407 };
408
409 default: {
410 TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
411 return false;
412 }
413 }
414 }
415 return true;
416 }
417
MergeLabelsIntoLeftSequence(const std::vector<PodNerModel_::LabelT> & labels_right,int index_first_right_tag_in_left,std::vector<PodNerModel_::LabelT> * labels_left)418 bool MergeLabelsIntoLeftSequence(
419 const std::vector<PodNerModel_::LabelT> &labels_right,
420 int index_first_right_tag_in_left,
421 std::vector<PodNerModel_::LabelT> *labels_left) {
422 if (index_first_right_tag_in_left > labels_left->size()) {
423 return false;
424 }
425
426 int overlaping_from_left =
427 (labels_left->size() - index_first_right_tag_in_left) / 2;
428
429 labels_left->resize(index_first_right_tag_in_left + labels_right.size());
430 std::copy(labels_right.begin() + overlaping_from_left, labels_right.end(),
431 labels_left->begin() + index_first_right_tag_in_left +
432 overlaping_from_left);
433 return true;
434 }
435
436 } // namespace libtextclassifier3
437