1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "annotator/feature-processor.h"

#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
26 
27 namespace libtextclassifier3 {
28 
29 namespace internal {
30 
BuildTokenizer(const FeatureProcessorOptions * options,const UniLib * unilib)31 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
32                          const UniLib* unilib) {
33   std::vector<const TokenizationCodepointRange*> codepoint_config;
34   if (options->tokenization_codepoint_config() != nullptr) {
35     codepoint_config.insert(codepoint_config.end(),
36                             options->tokenization_codepoint_config()->begin(),
37                             options->tokenization_codepoint_config()->end());
38   }
39   std::vector<const CodepointRange*> internal_codepoint_config;
40   if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
41     internal_codepoint_config.insert(
42         internal_codepoint_config.end(),
43         options->internal_tokenizer_codepoint_ranges()->begin(),
44         options->internal_tokenizer_codepoint_ranges()->end());
45   }
46   const bool tokenize_on_script_change =
47       options->tokenization_codepoint_config() != nullptr &&
48       options->tokenize_on_script_change();
49   return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
50                    internal_codepoint_config, tokenize_on_script_change,
51                    options->icu_preserve_whitespace_tokens());
52 }
53 
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions * const options)54 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
55     const FeatureProcessorOptions* const options) {
56   TokenFeatureExtractorOptions extractor_options;
57 
58   extractor_options.num_buckets = options->num_buckets();
59   if (options->chargram_orders() != nullptr) {
60     for (int order : *options->chargram_orders()) {
61       extractor_options.chargram_orders.push_back(order);
62     }
63   }
64   extractor_options.max_word_length = options->max_word_length();
65   extractor_options.extract_case_feature = options->extract_case_feature();
66   extractor_options.unicode_aware_features = options->unicode_aware_features();
67   extractor_options.extract_selection_mask_feature =
68       options->extract_selection_mask_feature();
69   if (options->regexp_feature() != nullptr) {
70     for (const auto& regexp_feauture : *options->regexp_feature()) {
71       extractor_options.regexp_features.push_back(regexp_feauture->str());
72     }
73   }
74   extractor_options.remap_digits = options->remap_digits();
75   extractor_options.lowercase_tokens = options->lowercase_tokens();
76 
77   if (options->allowed_chargrams() != nullptr) {
78     for (const auto& chargram : *options->allowed_chargrams()) {
79       extractor_options.allowed_chargrams.insert(chargram->str());
80     }
81   }
82   return extractor_options;
83 }
84 
SplitTokensOnSelectionBoundaries(CodepointSpan selection,std::vector<Token> * tokens)85 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
86                                       std::vector<Token>* tokens) {
87   for (auto it = tokens->begin(); it != tokens->end(); ++it) {
88     const UnicodeText token_word =
89         UTF8ToUnicodeText(it->value, /*do_copy=*/false);
90 
91     auto last_start = token_word.begin();
92     int last_start_index = it->start;
93     std::vector<UnicodeText::const_iterator> split_points;
94 
95     // Selection start split point.
96     if (selection.first > it->start && selection.first < it->end) {
97       std::advance(last_start, selection.first - last_start_index);
98       split_points.push_back(last_start);
99       last_start_index = selection.first;
100     }
101 
102     // Selection end split point.
103     if (selection.second > it->start && selection.second < it->end) {
104       std::advance(last_start, selection.second - last_start_index);
105       split_points.push_back(last_start);
106     }
107 
108     if (!split_points.empty()) {
109       // Add a final split for the rest of the token unless it's been all
110       // consumed already.
111       if (split_points.back() != token_word.end()) {
112         split_points.push_back(token_word.end());
113       }
114 
115       std::vector<Token> replacement_tokens;
116       last_start = token_word.begin();
117       int current_pos = it->start;
118       for (const auto& split_point : split_points) {
119         Token new_token(token_word.UTF8Substring(last_start, split_point),
120                         current_pos,
121                         current_pos + std::distance(last_start, split_point));
122 
123         last_start = split_point;
124         current_pos = new_token.end;
125 
126         replacement_tokens.push_back(new_token);
127       }
128 
129       it = tokens->erase(it);
130       it = tokens->insert(it, replacement_tokens.begin(),
131                           replacement_tokens.end());
132       std::advance(it, replacement_tokens.size() - 1);
133     }
134   }
135 }
136 
137 }  // namespace internal
138 
StripTokensFromOtherLines(const std::string & context,CodepointSpan span,std::vector<Token> * tokens) const139 void FeatureProcessor::StripTokensFromOtherLines(
140     const std::string& context, CodepointSpan span,
141     std::vector<Token>* tokens) const {
142   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
143                                                         /*do_copy=*/false);
144   StripTokensFromOtherLines(context_unicode, span, tokens);
145 }
146 
StripTokensFromOtherLines(const UnicodeText & context_unicode,CodepointSpan span,std::vector<Token> * tokens) const147 void FeatureProcessor::StripTokensFromOtherLines(
148     const UnicodeText& context_unicode, CodepointSpan span,
149     std::vector<Token>* tokens) const {
150   std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
151 
152   auto span_start = context_unicode.begin();
153   if (span.first > 0) {
154     std::advance(span_start, span.first);
155   }
156   auto span_end = context_unicode.begin();
157   if (span.second > 0) {
158     std::advance(span_end, span.second);
159   }
160   for (const UnicodeTextRange& line : lines) {
161     // Find the line that completely contains the span.
162     if (line.first <= span_start && line.second >= span_end) {
163       const CodepointIndex last_line_begin_index =
164           std::distance(context_unicode.begin(), line.first);
165       const CodepointIndex last_line_end_index =
166           last_line_begin_index + std::distance(line.first, line.second);
167 
168       for (auto token = tokens->begin(); token != tokens->end();) {
169         if (token->start >= last_line_begin_index &&
170             token->end <= last_line_end_index) {
171           ++token;
172         } else {
173           token = tokens->erase(token);
174         }
175       }
176     }
177   }
178 }
179 
GetDefaultCollection() const180 std::string FeatureProcessor::GetDefaultCollection() const {
181   if (options_->default_collection() < 0 ||
182       options_->collections() == nullptr ||
183       options_->default_collection() >= options_->collections()->size()) {
184     TC3_LOG(ERROR)
185         << "Invalid or missing default collection. Returning empty string.";
186     return "";
187   }
188   return (*options_->collections())[options_->default_collection()]->str();
189 }
190 
Tokenize(const std::string & text) const191 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
192   return tokenizer_.Tokenize(text);
193 }
194 
Tokenize(const UnicodeText & text_unicode) const195 std::vector<Token> FeatureProcessor::Tokenize(
196     const UnicodeText& text_unicode) const {
197   return tokenizer_.Tokenize(text_unicode);
198 }
199 
LabelToSpan(const int label,const VectorSpan<Token> & tokens,std::pair<CodepointIndex,CodepointIndex> * span) const200 bool FeatureProcessor::LabelToSpan(
201     const int label, const VectorSpan<Token>& tokens,
202     std::pair<CodepointIndex, CodepointIndex>* span) const {
203   if (tokens.size() != GetNumContextTokens()) {
204     return false;
205   }
206 
207   TokenSpan token_span;
208   if (!LabelToTokenSpan(label, &token_span)) {
209     return false;
210   }
211 
212   const int result_begin_token_index = token_span.first;
213   const Token& result_begin_token =
214       tokens[options_->context_size() - result_begin_token_index];
215   const int result_begin_codepoint = result_begin_token.start;
216   const int result_end_token_index = token_span.second;
217   const Token& result_end_token =
218       tokens[options_->context_size() + result_end_token_index];
219   const int result_end_codepoint = result_end_token.end;
220 
221   if (result_begin_codepoint == kInvalidIndex ||
222       result_end_codepoint == kInvalidIndex) {
223     *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
224   } else {
225     const UnicodeText token_begin_unicode =
226         UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
227     UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
228     const UnicodeText token_end_unicode =
229         UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
230     UnicodeText::const_iterator token_end = token_end_unicode.end();
231 
232     const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
233         token_begin, token_begin_unicode.end(),
234         /*count_from_beginning=*/true);
235     const int end_ignored =
236         CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
237                                            /*count_from_beginning=*/false);
238     // In case everything would be stripped, set the span to the original
239     // beginning and zero length.
240     if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
241       *span = {result_begin_codepoint, result_begin_codepoint};
242     } else {
243       *span = CodepointSpan({result_begin_codepoint + begin_ignored,
244                              result_end_codepoint - end_ignored});
245     }
246   }
247   return true;
248 }
249 
LabelToTokenSpan(const int label,TokenSpan * token_span) const250 bool FeatureProcessor::LabelToTokenSpan(const int label,
251                                         TokenSpan* token_span) const {
252   if (label >= 0 && label < label_to_selection_.size()) {
253     *token_span = label_to_selection_[label];
254     return true;
255   } else {
256     return false;
257   }
258 }
259 
SpanToLabel(const std::pair<CodepointIndex,CodepointIndex> & span,const std::vector<Token> & tokens,int * label) const260 bool FeatureProcessor::SpanToLabel(
261     const std::pair<CodepointIndex, CodepointIndex>& span,
262     const std::vector<Token>& tokens, int* label) const {
263   if (tokens.size() != GetNumContextTokens()) {
264     return false;
265   }
266 
267   const int click_position =
268       options_->context_size();  // Click is always in the middle.
269   const int padding = options_->context_size() - options_->max_selection_span();
270 
271   int span_left = 0;
272   for (int i = click_position - 1; i >= padding; i--) {
273     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
274       ++span_left;
275     } else {
276       break;
277     }
278   }
279 
280   int span_right = 0;
281   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
282     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
283       ++span_right;
284     } else {
285       break;
286     }
287   }
288 
289   // Check that the spanned tokens cover the whole span.
290   bool tokens_match_span;
291   const CodepointIndex tokens_start = tokens[click_position - span_left].start;
292   const CodepointIndex tokens_end = tokens[click_position + span_right].end;
293   if (options_->snap_label_span_boundaries_to_containing_tokens()) {
294     tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
295   } else {
296     const UnicodeText token_left_unicode = UTF8ToUnicodeText(
297         tokens[click_position - span_left].value, /*do_copy=*/false);
298     const UnicodeText token_right_unicode = UTF8ToUnicodeText(
299         tokens[click_position + span_right].value, /*do_copy=*/false);
300 
301     UnicodeText::const_iterator span_begin = token_left_unicode.begin();
302     UnicodeText::const_iterator span_end = token_right_unicode.end();
303 
304     const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
305         span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
306     const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
307         token_right_unicode.begin(), span_end,
308         /*count_from_beginning=*/false);
309 
310     tokens_match_span = tokens_start <= span.first &&
311                         tokens_start + num_punctuation_start >= span.first &&
312                         tokens_end >= span.second &&
313                         tokens_end - num_punctuation_end <= span.second;
314   }
315 
316   if (tokens_match_span) {
317     *label = TokenSpanToLabel({span_left, span_right});
318   } else {
319     *label = kInvalidLabel;
320   }
321 
322   return true;
323 }
324 
TokenSpanToLabel(const TokenSpan & span) const325 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
326   auto it = selection_to_label_.find(span);
327   if (it != selection_to_label_.end()) {
328     return it->second;
329   } else {
330     return kInvalidLabel;
331   }
332 }
333 
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span,bool snap_boundaries_to_containing_tokens)334 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
335                                    CodepointSpan codepoint_span,
336                                    bool snap_boundaries_to_containing_tokens) {
337   const int codepoint_start = std::get<0>(codepoint_span);
338   const int codepoint_end = std::get<1>(codepoint_span);
339 
340   TokenIndex start_token = kInvalidIndex;
341   TokenIndex end_token = kInvalidIndex;
342   for (int i = 0; i < selectable_tokens.size(); ++i) {
343     bool is_token_in_span;
344     if (snap_boundaries_to_containing_tokens) {
345       is_token_in_span = codepoint_start < selectable_tokens[i].end &&
346                          codepoint_end > selectable_tokens[i].start;
347     } else {
348       is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
349                          codepoint_end >= selectable_tokens[i].end;
350     }
351     if (is_token_in_span && !selectable_tokens[i].is_padding) {
352       if (start_token == kInvalidIndex) {
353         start_token = i;
354       }
355       end_token = i + 1;
356     }
357   }
358   return {start_token, end_token};
359 }
360 
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,TokenSpan token_span)361 CodepointSpan TokenSpanToCodepointSpan(
362     const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
363   return {selectable_tokens[token_span.first].start,
364           selectable_tokens[token_span.second - 1].end};
365 }
366 
367 namespace {
368 
369 // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span)370 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
371                               CodepointSpan codepoint_span) {
372   const int codepoint_start = std::get<0>(codepoint_span);
373   const int codepoint_end = std::get<1>(codepoint_span);
374 
375   for (int i = 0; i < selectable_tokens.size(); ++i) {
376     if (codepoint_start >= selectable_tokens[i].start &&
377         codepoint_end <= selectable_tokens[i].end) {
378       return i;
379     }
380   }
381   return kInvalidIndex;
382 }
383 
384 }  // namespace
385 
386 namespace internal {
387 
CenterTokenFromClick(CodepointSpan span,const std::vector<Token> & selectable_tokens)388 int CenterTokenFromClick(CodepointSpan span,
389                          const std::vector<Token>& selectable_tokens) {
390   int range_begin;
391   int range_end;
392   std::tie(range_begin, range_end) =
393       CodepointSpanToTokenSpan(selectable_tokens, span);
394 
395   // If no exact match was found, try finding a token that completely contains
396   // the click span. This is useful e.g. when Android builds the selection
397   // using ICU tokenization, and ends up with only a portion of our space-
398   // separated token. E.g. for "(857)" Android would select "857".
399   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
400     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
401     if (token_index != kInvalidIndex) {
402       range_begin = token_index;
403       range_end = token_index + 1;
404     }
405   }
406 
407   // We only allow clicks that are exactly 1 selectable token.
408   if (range_end - range_begin == 1) {
409     return range_begin;
410   } else {
411     return kInvalidIndex;
412   }
413 }
414 
CenterTokenFromMiddleOfSelection(CodepointSpan span,const std::vector<Token> & selectable_tokens)415 int CenterTokenFromMiddleOfSelection(
416     CodepointSpan span, const std::vector<Token>& selectable_tokens) {
417   int range_begin;
418   int range_end;
419   std::tie(range_begin, range_end) =
420       CodepointSpanToTokenSpan(selectable_tokens, span);
421 
422   // Center the clicked token in the selection range.
423   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
424     return (range_begin + range_end - 1) / 2;
425   } else {
426     return kInvalidIndex;
427   }
428 }
429 
430 }  // namespace internal
431 
FindCenterToken(CodepointSpan span,const std::vector<Token> & tokens) const432 int FeatureProcessor::FindCenterToken(CodepointSpan span,
433                                       const std::vector<Token>& tokens) const {
434   if (options_->center_token_selection_method() ==
435       FeatureProcessorOptions_::
436           CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
437     return internal::CenterTokenFromClick(span, tokens);
438   } else if (options_->center_token_selection_method() ==
439              FeatureProcessorOptions_::
440                  CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
441     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
442   } else if (options_->center_token_selection_method() ==
443              FeatureProcessorOptions_::
444                  CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
445     // TODO(zilka): Remove once we have new models on the device.
446     // It uses the fact that sharing model use
447     // split_tokens_on_selection_boundaries and selection not. So depending on
448     // this we select the right way of finding the click location.
449     if (!options_->split_tokens_on_selection_boundaries()) {
450       // SmartSelection model.
451       return internal::CenterTokenFromClick(span, tokens);
452     } else {
453       // SmartSharing model.
454       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
455     }
456   } else {
457     TC3_LOG(ERROR) << "Invalid center token selection method.";
458     return kInvalidIndex;
459   }
460 }
461 
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const462 bool FeatureProcessor::SelectionLabelSpans(
463     const VectorSpan<Token> tokens,
464     std::vector<CodepointSpan>* selection_label_spans) const {
465   for (int i = 0; i < label_to_selection_.size(); ++i) {
466     CodepointSpan span;
467     if (!LabelToSpan(i, tokens, &span)) {
468       TC3_LOG(ERROR) << "Could not convert label to span: " << i;
469       return false;
470     }
471     selection_label_spans->push_back(span);
472   }
473   return true;
474 }
475 
PrepareIgnoredSpanBoundaryCodepoints()476 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
477   if (options_->ignored_span_boundary_codepoints() != nullptr) {
478     for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
479       ignored_span_boundary_codepoints_.insert(codepoint);
480     }
481   }
482 }
483 
CountIgnoredSpanBoundaryCodepoints(const UnicodeText::const_iterator & span_start,const UnicodeText::const_iterator & span_end,bool count_from_beginning) const484 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
485     const UnicodeText::const_iterator& span_start,
486     const UnicodeText::const_iterator& span_end,
487     bool count_from_beginning) const {
488   if (span_start == span_end) {
489     return 0;
490   }
491 
492   UnicodeText::const_iterator it;
493   UnicodeText::const_iterator it_last;
494   if (count_from_beginning) {
495     it = span_start;
496     it_last = span_end;
497     // We can assume that the string is non-zero length because of the check
498     // above, thus the decrement is always valid here.
499     --it_last;
500   } else {
501     it = span_end;
502     it_last = span_start;
503     // We can assume that the string is non-zero length because of the check
504     // above, thus the decrement is always valid here.
505     --it;
506   }
507 
508   // Move until we encounter a non-ignored character.
509   int num_ignored = 0;
510   while (ignored_span_boundary_codepoints_.find(*it) !=
511          ignored_span_boundary_codepoints_.end()) {
512     ++num_ignored;
513 
514     if (it == it_last) {
515       break;
516     }
517 
518     if (count_from_beginning) {
519       ++it;
520     } else {
521       --it;
522     }
523   }
524 
525   return num_ignored;
526 }
527 
528 namespace {
529 
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)530 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
531                     std::vector<UnicodeTextRange>* ranges) {
532   UnicodeText::const_iterator start = t.begin();
533   UnicodeText::const_iterator curr = start;
534   UnicodeText::const_iterator end = t.end();
535   for (; curr != end; ++curr) {
536     if (codepoints.find(*curr) != codepoints.end()) {
537       if (start != curr) {
538         ranges->push_back(std::make_pair(start, curr));
539       }
540       start = curr;
541       ++start;
542     }
543   }
544   if (start != end) {
545     ranges->push_back(std::make_pair(start, end));
546   }
547 }
548 
549 }  // namespace
550 
SplitContext(const UnicodeText & context_unicode) const551 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
552     const UnicodeText& context_unicode) const {
553   std::vector<UnicodeTextRange> lines;
554   const std::set<char32> codepoints{{'\n', '|'}};
555   FindSubstrings(context_unicode, codepoints, &lines);
556   return lines;
557 }
558 
StripBoundaryCodepoints(const std::string & context,CodepointSpan span) const559 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
560     const std::string& context, CodepointSpan span) const {
561   const UnicodeText context_unicode =
562       UTF8ToUnicodeText(context, /*do_copy=*/false);
563   return StripBoundaryCodepoints(context_unicode, span);
564 }
565 
StripBoundaryCodepoints(const UnicodeText & context_unicode,CodepointSpan span) const566 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
567     const UnicodeText& context_unicode, CodepointSpan span) const {
568   if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
569     return span;
570   }
571 
572   UnicodeText::const_iterator span_begin = context_unicode.begin();
573   std::advance(span_begin, span.first);
574   UnicodeText::const_iterator span_end = context_unicode.begin();
575   std::advance(span_end, span.second);
576 
577   return StripBoundaryCodepoints(span_begin, span_end, span);
578 }
579 
StripBoundaryCodepoints(const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,CodepointSpan span) const580 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
581     const UnicodeText::const_iterator& span_begin,
582     const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
583   if (!ValidNonEmptySpan(span) || span_begin == span_end) {
584     return span;
585   }
586 
587   const int start_offset = CountIgnoredSpanBoundaryCodepoints(
588       span_begin, span_end, /*count_from_beginning=*/true);
589   const int end_offset = CountIgnoredSpanBoundaryCodepoints(
590       span_begin, span_end, /*count_from_beginning=*/false);
591 
592   if (span.first + start_offset < span.second - end_offset) {
593     return {span.first + start_offset, span.second - end_offset};
594   } else {
595     return {span.first, span.first};
596   }
597 }
598 
SupportedCodepointsRatio(const TokenSpan & token_span,const std::vector<Token> & tokens) const599 float FeatureProcessor::SupportedCodepointsRatio(
600     const TokenSpan& token_span, const std::vector<Token>& tokens) const {
601   int num_supported = 0;
602   int num_total = 0;
603   for (int i = token_span.first; i < token_span.second; ++i) {
604     const UnicodeText value =
605         UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
606     for (auto codepoint : value) {
607       if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
608         ++num_supported;
609       }
610       ++num_total;
611     }
612   }
613   return static_cast<float>(num_supported) / static_cast<float>(num_total);
614 }
615 
StripBoundaryCodepoints(const std::string & value,std::string * buffer) const616 const std::string& FeatureProcessor::StripBoundaryCodepoints(
617     const std::string& value, std::string* buffer) const {
618   const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
619   const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
620   const CodepointSpan stripped_span =
621       StripBoundaryCodepoints(value_unicode, initial_span);
622 
623   if (initial_span != stripped_span) {
624     const UnicodeText stripped_token_value =
625         UnicodeText::Substring(value_unicode, stripped_span.first,
626                                stripped_span.second, /*do_copy=*/false);
627     *buffer = stripped_token_value.ToUTF8String();
628     return *buffer;
629   }
630   return value;
631 }
632 
CollectionToLabel(const std::string & collection) const633 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
634   const auto it = collection_to_label_.find(collection);
635   if (it == collection_to_label_.end()) {
636     return options_->default_collection();
637   } else {
638     return it->second;
639   }
640 }
641 
LabelToCollection(int label) const642 std::string FeatureProcessor::LabelToCollection(int label) const {
643   if (label >= 0 && label < collection_to_label_.size()) {
644     return (*options_->collections())[label]->str();
645   } else {
646     return GetDefaultCollection();
647   }
648 }
649 
MakeLabelMaps()650 void FeatureProcessor::MakeLabelMaps() {
651   if (options_->collections() != nullptr) {
652     for (int i = 0; i < options_->collections()->size(); ++i) {
653       collection_to_label_[(*options_->collections())[i]->str()] = i;
654     }
655   }
656 
657   int selection_label_id = 0;
658   for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
659     for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
660       if (!options_->selection_reduced_output_space() ||
661           r + l <= options_->max_selection_span()) {
662         TokenSpan token_span{l, r};
663         selection_to_label_[token_span] = selection_label_id;
664         label_to_selection_.push_back(token_span);
665         ++selection_label_id;
666       }
667     }
668   }
669 }
670 
RetokenizeAndFindClick(const std::string & context,CodepointSpan input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const671 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
672                                               CodepointSpan input_span,
673                                               bool only_use_line_with_click,
674                                               std::vector<Token>* tokens,
675                                               int* click_pos) const {
676   const UnicodeText context_unicode =
677       UTF8ToUnicodeText(context, /*do_copy=*/false);
678   RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
679                          tokens, click_pos);
680 }
681 
RetokenizeAndFindClick(const UnicodeText & context_unicode,CodepointSpan input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const682 void FeatureProcessor::RetokenizeAndFindClick(
683     const UnicodeText& context_unicode, CodepointSpan input_span,
684     bool only_use_line_with_click, std::vector<Token>* tokens,
685     int* click_pos) const {
686   TC3_CHECK(tokens != nullptr);
687 
688   if (options_->split_tokens_on_selection_boundaries()) {
689     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
690   }
691 
692   if (only_use_line_with_click) {
693     StripTokensFromOtherLines(context_unicode, input_span, tokens);
694   }
695 
696   int local_click_pos;
697   if (click_pos == nullptr) {
698     click_pos = &local_click_pos;
699   }
700   *click_pos = FindCenterToken(input_span, *tokens);
701   if (*click_pos == kInvalidIndex) {
702     // If the default click method failed, let's try to do sub-token matching
703     // before we fail.
704     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
705   }
706 }
707 
708 namespace internal {
709 
StripOrPadTokens(TokenSpan relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)710 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
711                       std::vector<Token>* tokens, int* click_pos) {
712   int right_context_needed = relative_click_span.second + context_size;
713   if (*click_pos + right_context_needed + 1 >= tokens->size()) {
714     // Pad max the context size.
715     const int num_pad_tokens = std::min(
716         context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
717                                        tokens->size()));
718     std::vector<Token> pad_tokens(num_pad_tokens);
719     tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
720   } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
721     // Strip unused tokens.
722     auto it = tokens->begin();
723     std::advance(it, *click_pos + right_context_needed + 1);
724     tokens->erase(it, tokens->end());
725   }
726 
727   int left_context_needed = relative_click_span.first + context_size;
728   if (*click_pos < left_context_needed) {
729     // Pad max the context size.
730     const int num_pad_tokens =
731         std::min(context_size, left_context_needed - *click_pos);
732     std::vector<Token> pad_tokens(num_pad_tokens);
733     tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
734     *click_pos += num_pad_tokens;
735   } else if (*click_pos > left_context_needed) {
736     // Strip unused tokens.
737     auto it = tokens->begin();
738     std::advance(it, *click_pos - left_context_needed);
739     *click_pos -= it - tokens->begin();
740     tokens->erase(tokens->begin(), it);
741   }
742 }
743 
744 }  // namespace internal
745 
HasEnoughSupportedCodepoints(const std::vector<Token> & tokens,TokenSpan token_span) const746 bool FeatureProcessor::HasEnoughSupportedCodepoints(
747     const std::vector<Token>& tokens, TokenSpan token_span) const {
748   if (options_->min_supported_codepoint_ratio() > 0) {
749     const float supported_codepoint_ratio =
750         SupportedCodepointsRatio(token_span, tokens);
751     if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
752       TC3_VLOG(1) << "Not enough supported codepoints in the context: "
753                   << supported_codepoint_ratio;
754       return false;
755     }
756   }
757   return true;
758 }
759 
ExtractFeatures(const std::vector<Token> & tokens,TokenSpan token_span,CodepointSpan selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,int feature_vector_size,std::unique_ptr<CachedFeatures> * cached_features) const760 bool FeatureProcessor::ExtractFeatures(
761     const std::vector<Token>& tokens, TokenSpan token_span,
762     CodepointSpan selection_span_for_feature,
763     const EmbeddingExecutor* embedding_executor,
764     EmbeddingCache* embedding_cache, int feature_vector_size,
765     std::unique_ptr<CachedFeatures>* cached_features) const {
766   std::unique_ptr<std::vector<float>> features(new std::vector<float>());
767   features->reserve(feature_vector_size * TokenSpanSize(token_span));
768   for (int i = token_span.first; i < token_span.second; ++i) {
769     if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
770                                       embedding_executor, embedding_cache,
771                                       features.get())) {
772       TC3_LOG(ERROR) << "Could not get token features.";
773       return false;
774     }
775   }
776 
777   std::unique_ptr<std::vector<float>> padding_features(
778       new std::vector<float>());
779   padding_features->reserve(feature_vector_size);
780   if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
781                                     embedding_executor, embedding_cache,
782                                     padding_features.get())) {
783     TC3_LOG(ERROR) << "Count not get padding token features.";
784     return false;
785   }
786 
787   *cached_features = CachedFeatures::Create(token_span, std::move(features),
788                                             std::move(padding_features),
789                                             options_, feature_vector_size);
790   if (!*cached_features) {
791     TC3_LOG(ERROR) << "Cound not create cached features.";
792     return false;
793   }
794 
795   return true;
796 }
797 
AppendTokenFeaturesWithCache(const Token & token,CodepointSpan selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,std::vector<float> * output_features) const798 bool FeatureProcessor::AppendTokenFeaturesWithCache(
799     const Token& token, CodepointSpan selection_span_for_feature,
800     const EmbeddingExecutor* embedding_executor,
801     EmbeddingCache* embedding_cache,
802     std::vector<float>* output_features) const {
803   // Look for the embedded features for the token in the cache, if there is one.
804   if (embedding_cache) {
805     const auto it = embedding_cache->find({token.start, token.end});
806     if (it != embedding_cache->end()) {
807       // The embedded features were found in the cache, extract only the dense
808       // features.
809       std::vector<float> dense_features;
810       if (!feature_extractor_.Extract(
811               token, token.IsContainedInSpan(selection_span_for_feature),
812               /*sparse_features=*/nullptr, &dense_features)) {
813         TC3_LOG(ERROR) << "Could not extract token's dense features.";
814         return false;
815       }
816 
817       // Append both embedded and dense features to the output and return.
818       output_features->insert(output_features->end(), it->second.begin(),
819                               it->second.end());
820       output_features->insert(output_features->end(), dense_features.begin(),
821                               dense_features.end());
822       return true;
823     }
824   }
825 
826   // Extract the sparse and dense features.
827   std::vector<int> sparse_features;
828   std::vector<float> dense_features;
829   if (!feature_extractor_.Extract(
830           token, token.IsContainedInSpan(selection_span_for_feature),
831           &sparse_features, &dense_features)) {
832     TC3_LOG(ERROR) << "Could not extract token's features.";
833     return false;
834   }
835 
836   // Embed the sparse features, appending them directly to the output.
837   const int embedding_size = GetOptions()->embedding_size();
838   output_features->resize(output_features->size() + embedding_size);
839   float* output_features_end =
840       output_features->data() + output_features->size();
841   if (!embedding_executor->AddEmbedding(
842           TensorView<int>(sparse_features.data(),
843                           {static_cast<int>(sparse_features.size())}),
844           /*dest=*/output_features_end - embedding_size,
845           /*dest_size=*/embedding_size)) {
846     TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
847     return false;
848   }
849 
850   // If there is a cache, the embedded features for the token were not in it,
851   // so insert them.
852   if (embedding_cache) {
853     (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
854         output_features_end - embedding_size, output_features_end);
855   }
856 
857   // Append the dense features to the output.
858   output_features->insert(output_features->end(), dense_features.begin(),
859                           dense_features.end());
860   return true;
861 }
862 
863 }  // namespace libtextclassifier3
864