1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "smartselect/feature-processor.h"
18 
19 #include <iterator>
20 #include <set>
21 #include <vector>
22 
23 #include "smartselect/text-classification-model.pb.h"
24 #include "util/base/logging.h"
25 #include "util/strings/utf8.h"
26 #include "util/utf8/unicodetext.h"
27 #include "unicode/brkiter.h"
28 #include "unicode/errorcode.h"
29 #include "unicode/uchar.h"
30 
31 namespace libtextclassifier {
32 
33 namespace internal {
34 
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions & options)35 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
36     const FeatureProcessorOptions& options) {
37   TokenFeatureExtractorOptions extractor_options;
38 
39   extractor_options.num_buckets = options.num_buckets();
40   for (int order : options.chargram_orders()) {
41     extractor_options.chargram_orders.push_back(order);
42   }
43   extractor_options.max_word_length = options.max_word_length();
44   extractor_options.extract_case_feature = options.extract_case_feature();
45   extractor_options.unicode_aware_features = options.unicode_aware_features();
46   extractor_options.extract_selection_mask_feature =
47       options.extract_selection_mask_feature();
48   for (int i = 0; i < options.regexp_feature_size(); ++i) {
49     extractor_options.regexp_features.push_back(options.regexp_feature(i));
50   }
51   extractor_options.remap_digits = options.remap_digits();
52   extractor_options.lowercase_tokens = options.lowercase_tokens();
53 
54   return extractor_options;
55 }
56 
ParseSerializedOptions(const std::string & serialized_options)57 FeatureProcessorOptions ParseSerializedOptions(
58     const std::string& serialized_options) {
59   FeatureProcessorOptions options;
60   options.ParseFromString(serialized_options);
61   return options;
62 }
63 
SplitTokensOnSelectionBoundaries(CodepointSpan selection,std::vector<Token> * tokens)64 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
65                                       std::vector<Token>* tokens) {
66   for (auto it = tokens->begin(); it != tokens->end(); ++it) {
67     const UnicodeText token_word =
68         UTF8ToUnicodeText(it->value, /*do_copy=*/false);
69 
70     auto last_start = token_word.begin();
71     int last_start_index = it->start;
72     std::vector<UnicodeText::const_iterator> split_points;
73 
74     // Selection start split point.
75     if (selection.first > it->start && selection.first < it->end) {
76       std::advance(last_start, selection.first - last_start_index);
77       split_points.push_back(last_start);
78       last_start_index = selection.first;
79     }
80 
81     // Selection end split point.
82     if (selection.second > it->start && selection.second < it->end) {
83       std::advance(last_start, selection.second - last_start_index);
84       split_points.push_back(last_start);
85     }
86 
87     if (!split_points.empty()) {
88       // Add a final split for the rest of the token unless it's been all
89       // consumed already.
90       if (split_points.back() != token_word.end()) {
91         split_points.push_back(token_word.end());
92       }
93 
94       std::vector<Token> replacement_tokens;
95       last_start = token_word.begin();
96       int current_pos = it->start;
97       for (const auto& split_point : split_points) {
98         Token new_token(token_word.UTF8Substring(last_start, split_point),
99                         current_pos,
100                         current_pos + std::distance(last_start, split_point));
101 
102         last_start = split_point;
103         current_pos = new_token.end;
104 
105         replacement_tokens.push_back(new_token);
106       }
107 
108       it = tokens->erase(it);
109       it = tokens->insert(it, replacement_tokens.begin(),
110                           replacement_tokens.end());
111       std::advance(it, replacement_tokens.size() - 1);
112     }
113   }
114 }
115 
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)116 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
117                     std::vector<UnicodeTextRange>* ranges) {
118   UnicodeText::const_iterator start = t.begin();
119   UnicodeText::const_iterator curr = start;
120   UnicodeText::const_iterator end = t.end();
121   for (; curr != end; ++curr) {
122     if (codepoints.find(*curr) != codepoints.end()) {
123       if (start != curr) {
124         ranges->push_back(std::make_pair(start, curr));
125       }
126       start = curr;
127       ++start;
128     }
129   }
130   if (start != end) {
131     ranges->push_back(std::make_pair(start, end));
132   }
133 }
134 
StripTokensFromOtherLines(const std::string & context,CodepointSpan span,std::vector<Token> * tokens)135 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
136                                std::vector<Token>* tokens) {
137   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
138                                                         /*do_copy=*/false);
139   std::vector<UnicodeTextRange> lines;
140   std::set<char32> codepoints;
141   codepoints.insert('\n');
142   codepoints.insert('|');
143   internal::FindSubstrings(context_unicode, codepoints, &lines);
144 
145   auto span_start = context_unicode.begin();
146   if (span.first > 0) {
147     std::advance(span_start, span.first);
148   }
149   auto span_end = context_unicode.begin();
150   if (span.second > 0) {
151     std::advance(span_end, span.second);
152   }
153   for (const UnicodeTextRange& line : lines) {
154     // Find the line that completely contains the span.
155     if (line.first <= span_start && line.second >= span_end) {
156       const CodepointIndex last_line_begin_index =
157           std::distance(context_unicode.begin(), line.first);
158       const CodepointIndex last_line_end_index =
159           last_line_begin_index + std::distance(line.first, line.second);
160 
161       for (auto token = tokens->begin(); token != tokens->end();) {
162         if (token->start >= last_line_begin_index &&
163             token->end <= last_line_end_index) {
164           ++token;
165         } else {
166           token = tokens->erase(token);
167         }
168       }
169     }
170   }
171 }
172 
173 }  // namespace internal
174 
GetDefaultCollection() const175 std::string FeatureProcessor::GetDefaultCollection() const {
176   if (options_.default_collection() >= options_.collections_size()) {
177     TC_LOG(ERROR) << "No collections specified. Returning empty string.";
178     return "";
179   }
180   return options_.collections(options_.default_collection());
181 }
182 
Tokenize(const std::string & utf8_text) const183 std::vector<Token> FeatureProcessor::Tokenize(
184     const std::string& utf8_text) const {
185   if (options_.tokenization_type() ==
186       libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
187     return tokenizer_.Tokenize(utf8_text);
188   } else if (options_.tokenization_type() ==
189                  libtextclassifier::FeatureProcessorOptions::ICU ||
190              options_.tokenization_type() ==
191                  libtextclassifier::FeatureProcessorOptions::MIXED) {
192     std::vector<Token> result;
193     if (!ICUTokenize(utf8_text, &result)) {
194       return {};
195     }
196     if (options_.tokenization_type() ==
197         libtextclassifier::FeatureProcessorOptions::MIXED) {
198       InternalRetokenize(utf8_text, &result);
199     }
200     return result;
201   } else {
202     TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
203                      "internal.";
204     return tokenizer_.Tokenize(utf8_text);
205   }
206 }
207 
LabelToSpan(const int label,const VectorSpan<Token> & tokens,std::pair<CodepointIndex,CodepointIndex> * span) const208 bool FeatureProcessor::LabelToSpan(
209     const int label, const VectorSpan<Token>& tokens,
210     std::pair<CodepointIndex, CodepointIndex>* span) const {
211   if (tokens.size() != GetNumContextTokens()) {
212     return false;
213   }
214 
215   TokenSpan token_span;
216   if (!LabelToTokenSpan(label, &token_span)) {
217     return false;
218   }
219 
220   const int result_begin_token = token_span.first;
221   const int result_begin_codepoint =
222       tokens[options_.context_size() - result_begin_token].start;
223   const int result_end_token = token_span.second;
224   const int result_end_codepoint =
225       tokens[options_.context_size() + result_end_token].end;
226 
227   if (result_begin_codepoint == kInvalidIndex ||
228       result_end_codepoint == kInvalidIndex) {
229     *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
230   } else {
231     *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
232   }
233   return true;
234 }
235 
LabelToTokenSpan(const int label,TokenSpan * token_span) const236 bool FeatureProcessor::LabelToTokenSpan(const int label,
237                                         TokenSpan* token_span) const {
238   if (label >= 0 && label < label_to_selection_.size()) {
239     *token_span = label_to_selection_[label];
240     return true;
241   } else {
242     return false;
243   }
244 }
245 
SpanToLabel(const std::pair<CodepointIndex,CodepointIndex> & span,const std::vector<Token> & tokens,int * label) const246 bool FeatureProcessor::SpanToLabel(
247     const std::pair<CodepointIndex, CodepointIndex>& span,
248     const std::vector<Token>& tokens, int* label) const {
249   if (tokens.size() != GetNumContextTokens()) {
250     return false;
251   }
252 
253   const int click_position =
254       options_.context_size();  // Click is always in the middle.
255   const int padding = options_.context_size() - options_.max_selection_span();
256 
257   int span_left = 0;
258   for (int i = click_position - 1; i >= padding; i--) {
259     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
260       ++span_left;
261     } else {
262       break;
263     }
264   }
265 
266   int span_right = 0;
267   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
268     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
269       ++span_right;
270     } else {
271       break;
272     }
273   }
274 
275   // Check that the spanned tokens cover the whole span.
276   bool tokens_match_span;
277   if (options_.snap_label_span_boundaries_to_containing_tokens()) {
278     tokens_match_span =
279         tokens[click_position - span_left].start <= span.first &&
280         tokens[click_position + span_right].end >= span.second;
281   } else {
282     tokens_match_span =
283         tokens[click_position - span_left].start == span.first &&
284         tokens[click_position + span_right].end == span.second;
285   }
286 
287   if (tokens_match_span) {
288     *label = TokenSpanToLabel({span_left, span_right});
289   } else {
290     *label = kInvalidLabel;
291   }
292 
293   return true;
294 }
295 
TokenSpanToLabel(const TokenSpan & span) const296 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
297   auto it = selection_to_label_.find(span);
298   if (it != selection_to_label_.end()) {
299     return it->second;
300   } else {
301     return kInvalidLabel;
302   }
303 }
304 
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span)305 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
306                                    CodepointSpan codepoint_span) {
307   const int codepoint_start = std::get<0>(codepoint_span);
308   const int codepoint_end = std::get<1>(codepoint_span);
309 
310   TokenIndex start_token = kInvalidIndex;
311   TokenIndex end_token = kInvalidIndex;
312   for (int i = 0; i < selectable_tokens.size(); ++i) {
313     if (codepoint_start <= selectable_tokens[i].start &&
314         codepoint_end >= selectable_tokens[i].end &&
315         !selectable_tokens[i].is_padding) {
316       if (start_token == kInvalidIndex) {
317         start_token = i;
318       }
319       end_token = i + 1;
320     }
321   }
322   return {start_token, end_token};
323 }
324 
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,TokenSpan token_span)325 CodepointSpan TokenSpanToCodepointSpan(
326     const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
327   return {selectable_tokens[token_span.first].start,
328           selectable_tokens[token_span.second - 1].end};
329 }
330 
331 namespace {
332 
333 // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span)334 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
335                               CodepointSpan codepoint_span) {
336   const int codepoint_start = std::get<0>(codepoint_span);
337   const int codepoint_end = std::get<1>(codepoint_span);
338 
339   for (int i = 0; i < selectable_tokens.size(); ++i) {
340     if (codepoint_start >= selectable_tokens[i].start &&
341         codepoint_end <= selectable_tokens[i].end) {
342       return i;
343     }
344   }
345   return kInvalidIndex;
346 }
347 
348 }  // namespace
349 
350 namespace internal {
351 
CenterTokenFromClick(CodepointSpan span,const std::vector<Token> & selectable_tokens)352 int CenterTokenFromClick(CodepointSpan span,
353                          const std::vector<Token>& selectable_tokens) {
354   int range_begin;
355   int range_end;
356   std::tie(range_begin, range_end) =
357       CodepointSpanToTokenSpan(selectable_tokens, span);
358 
359   // If no exact match was found, try finding a token that completely contains
360   // the click span. This is useful e.g. when Android builds the selection
361   // using ICU tokenization, and ends up with only a portion of our space-
362   // separated token. E.g. for "(857)" Android would select "857".
363   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
364     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
365     if (token_index != kInvalidIndex) {
366       range_begin = token_index;
367       range_end = token_index + 1;
368     }
369   }
370 
371   // We only allow clicks that are exactly 1 selectable token.
372   if (range_end - range_begin == 1) {
373     return range_begin;
374   } else {
375     return kInvalidIndex;
376   }
377 }
378 
CenterTokenFromMiddleOfSelection(CodepointSpan span,const std::vector<Token> & selectable_tokens)379 int CenterTokenFromMiddleOfSelection(
380     CodepointSpan span, const std::vector<Token>& selectable_tokens) {
381   int range_begin;
382   int range_end;
383   std::tie(range_begin, range_end) =
384       CodepointSpanToTokenSpan(selectable_tokens, span);
385 
386   // Center the clicked token in the selection range.
387   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
388     return (range_begin + range_end - 1) / 2;
389   } else {
390     return kInvalidIndex;
391   }
392 }
393 
394 }  // namespace internal
395 
FindCenterToken(CodepointSpan span,const std::vector<Token> & tokens) const396 int FeatureProcessor::FindCenterToken(CodepointSpan span,
397                                       const std::vector<Token>& tokens) const {
398   if (options_.center_token_selection_method() ==
399       FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK) {
400     return internal::CenterTokenFromClick(span, tokens);
401   } else if (options_.center_token_selection_method() ==
402              FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION) {
403     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
404   } else if (options_.center_token_selection_method() ==
405              FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD) {
406     // TODO(zilka): Remove once we have new models on the device.
407     // It uses the fact that sharing model use
408     // split_tokens_on_selection_boundaries and selection not. So depending on
409     // this we select the right way of finding the click location.
410     if (!options_.split_tokens_on_selection_boundaries()) {
411       // SmartSelection model.
412       return internal::CenterTokenFromClick(span, tokens);
413     } else {
414       // SmartSharing model.
415       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
416     }
417   } else {
418     TC_LOG(ERROR) << "Invalid center token selection method.";
419     return kInvalidIndex;
420   }
421 }
422 
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const423 bool FeatureProcessor::SelectionLabelSpans(
424     const VectorSpan<Token> tokens,
425     std::vector<CodepointSpan>* selection_label_spans) const {
426   for (int i = 0; i < label_to_selection_.size(); ++i) {
427     CodepointSpan span;
428     if (!LabelToSpan(i, tokens, &span)) {
429       TC_LOG(ERROR) << "Could not convert label to span: " << i;
430       return false;
431     }
432     selection_label_spans->push_back(span);
433   }
434   return true;
435 }
436 
PrepareCodepointRanges(const std::vector<FeatureProcessorOptions::CodepointRange> & codepoint_ranges,std::vector<CodepointRange> * prepared_codepoint_ranges)437 void FeatureProcessor::PrepareCodepointRanges(
438     const std::vector<FeatureProcessorOptions::CodepointRange>&
439         codepoint_ranges,
440     std::vector<CodepointRange>* prepared_codepoint_ranges) {
441   prepared_codepoint_ranges->clear();
442   prepared_codepoint_ranges->reserve(codepoint_ranges.size());
443   for (const FeatureProcessorOptions::CodepointRange& range :
444        codepoint_ranges) {
445     prepared_codepoint_ranges->push_back(
446         CodepointRange(range.start(), range.end()));
447   }
448 
449   std::sort(prepared_codepoint_ranges->begin(),
450             prepared_codepoint_ranges->end(),
451             [](const CodepointRange& a, const CodepointRange& b) {
452               return a.start < b.start;
453             });
454 }
455 
SupportedCodepointsRatio(int click_pos,const std::vector<Token> & tokens) const456 float FeatureProcessor::SupportedCodepointsRatio(
457     int click_pos, const std::vector<Token>& tokens) const {
458   int num_supported = 0;
459   int num_total = 0;
460   for (int i = click_pos - options_.context_size();
461        i <= click_pos + options_.context_size(); ++i) {
462     const bool is_valid_token = i >= 0 && i < tokens.size();
463     if (is_valid_token) {
464       const UnicodeText value =
465           UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
466       for (auto codepoint : value) {
467         if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
468           ++num_supported;
469         }
470         ++num_total;
471       }
472     }
473   }
474   return static_cast<float>(num_supported) / static_cast<float>(num_total);
475 }
476 
IsCodepointInRanges(int codepoint,const std::vector<CodepointRange> & codepoint_ranges) const477 bool FeatureProcessor::IsCodepointInRanges(
478     int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
479   auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
480                              codepoint,
481                              [](const CodepointRange& range, int codepoint) {
482                                // This function compares range with the
483                                // codepoint for the purpose of finding the first
484                                // greater or equal range. Because of the use of
485                                // std::lower_bound it needs to return true when
486                                // range < codepoint; the first time it will
487                                // return false the lower bound is found and
488                                // returned.
489                                //
490                                // It might seem weird that the condition is
491                                // range.end <= codepoint here but when codepoint
492                                // == range.end it means it's actually just
493                                // outside of the range, thus the range is less
494                                // than the codepoint.
495                                return range.end <= codepoint;
496                              });
497   if (it != codepoint_ranges.end() && it->start <= codepoint &&
498       it->end > codepoint) {
499     return true;
500   } else {
501     return false;
502   }
503 }
504 
CollectionToLabel(const std::string & collection) const505 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
506   const auto it = collection_to_label_.find(collection);
507   if (it == collection_to_label_.end()) {
508     return options_.default_collection();
509   } else {
510     return it->second;
511   }
512 }
513 
LabelToCollection(int label) const514 std::string FeatureProcessor::LabelToCollection(int label) const {
515   if (label >= 0 && label < collection_to_label_.size()) {
516     return options_.collections(label);
517   } else {
518     return GetDefaultCollection();
519   }
520 }
521 
MakeLabelMaps()522 void FeatureProcessor::MakeLabelMaps() {
523   for (int i = 0; i < options_.collections().size(); ++i) {
524     collection_to_label_[options_.collections(i)] = i;
525   }
526 
527   int selection_label_id = 0;
528   for (int l = 0; l < (options_.max_selection_span() + 1); ++l) {
529     for (int r = 0; r < (options_.max_selection_span() + 1); ++r) {
530       if (!options_.selection_reduced_output_space() ||
531           r + l <= options_.max_selection_span()) {
532         TokenSpan token_span{l, r};
533         selection_to_label_[token_span] = selection_label_id;
534         label_to_selection_.push_back(token_span);
535         ++selection_label_id;
536       }
537     }
538   }
539 }
540 
TokenizeAndFindClick(const std::string & context,CodepointSpan input_span,std::vector<Token> * tokens,int * click_pos) const541 void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
542                                             CodepointSpan input_span,
543                                             std::vector<Token>* tokens,
544                                             int* click_pos) const {
545   TC_CHECK(tokens != nullptr);
546   *tokens = Tokenize(context);
547 
548   if (options_.split_tokens_on_selection_boundaries()) {
549     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
550   }
551 
552   if (options_.only_use_line_with_click()) {
553     internal::StripTokensFromOtherLines(context, input_span, tokens);
554   }
555 
556   int local_click_pos;
557   if (click_pos == nullptr) {
558     click_pos = &local_click_pos;
559   }
560   *click_pos = FindCenterToken(input_span, *tokens);
561 }
562 
563 namespace internal {
564 
StripOrPadTokens(TokenSpan relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)565 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
566                       std::vector<Token>* tokens, int* click_pos) {
567   int right_context_needed = relative_click_span.second + context_size;
568   if (*click_pos + right_context_needed + 1 >= tokens->size()) {
569     // Pad max the context size.
570     const int num_pad_tokens = std::min(
571         context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
572                                        tokens->size()));
573     std::vector<Token> pad_tokens(num_pad_tokens);
574     tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
575   } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
576     // Strip unused tokens.
577     auto it = tokens->begin();
578     std::advance(it, *click_pos + right_context_needed + 1);
579     tokens->erase(it, tokens->end());
580   }
581 
582   int left_context_needed = relative_click_span.first + context_size;
583   if (*click_pos < left_context_needed) {
584     // Pad max the context size.
585     const int num_pad_tokens =
586         std::min(context_size, left_context_needed - *click_pos);
587     std::vector<Token> pad_tokens(num_pad_tokens);
588     tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
589     *click_pos += num_pad_tokens;
590   } else if (*click_pos > left_context_needed) {
591     // Strip unused tokens.
592     auto it = tokens->begin();
593     std::advance(it, *click_pos - left_context_needed);
594     *click_pos -= it - tokens->begin();
595     tokens->erase(tokens->begin(), it);
596   }
597 }
598 
599 }  // namespace internal
600 
ExtractFeatures(const std::string & context,CodepointSpan input_span,TokenSpan relative_click_span,const FeatureVectorFn & feature_vector_fn,int feature_vector_size,std::vector<Token> * tokens,int * click_pos,std::unique_ptr<CachedFeatures> * cached_features) const601 bool FeatureProcessor::ExtractFeatures(
602     const std::string& context, CodepointSpan input_span,
603     TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
604     int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
605     std::unique_ptr<CachedFeatures>* cached_features) const {
606   TokenizeAndFindClick(context, input_span, tokens, click_pos);
607 
608   // If the default click method failed, let's try to do sub-token matching
609   // before we fail.
610   if (*click_pos == kInvalidIndex) {
611     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
612     if (*click_pos == kInvalidIndex) {
613       return false;
614     }
615   }
616 
617   internal::StripOrPadTokens(relative_click_span, options_.context_size(),
618                              tokens, click_pos);
619 
620   if (options_.min_supported_codepoint_ratio() > 0) {
621     const float supported_codepoint_ratio =
622         SupportedCodepointsRatio(*click_pos, *tokens);
623     if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
624       TC_LOG(INFO) << "Not enough supported codepoints in the context: "
625                    << supported_codepoint_ratio;
626       return false;
627     }
628   }
629 
630   std::vector<std::vector<int>> sparse_features(tokens->size());
631   std::vector<std::vector<float>> dense_features(tokens->size());
632   for (int i = 0; i < tokens->size(); ++i) {
633     const Token& token = (*tokens)[i];
634     if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
635                                     &(sparse_features[i]),
636                                     &(dense_features[i]))) {
637       TC_LOG(ERROR) << "Could not extract token's features: " << token;
638       return false;
639     }
640   }
641 
642   cached_features->reset(new CachedFeatures(
643       *tokens, options_.context_size(), sparse_features, dense_features,
644       feature_vector_fn, feature_vector_size));
645 
646   if (*cached_features == nullptr) {
647     return false;
648   }
649 
650   if (options_.feature_version() == 0) {
651     (*cached_features)
652         ->SetV0FeatureMode(feature_vector_size -
653                            feature_extractor_.DenseFeaturesCount());
654   }
655 
656   return true;
657 }
658 
ICUTokenize(const std::string & context,std::vector<Token> * result) const659 bool FeatureProcessor::ICUTokenize(const std::string& context,
660                                    std::vector<Token>* result) const {
661   icu::ErrorCode status;
662   icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
663   std::unique_ptr<icu::BreakIterator> break_iterator(
664       icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
665   if (!status.isSuccess()) {
666     TC_LOG(ERROR) << "Break iterator did not initialize properly: "
667                   << status.errorName();
668     return false;
669   }
670 
671   break_iterator->setText(unicode_text);
672 
673   size_t last_break_index = 0;
674   size_t break_index = 0;
675   size_t last_unicode_index = 0;
676   size_t unicode_index = 0;
677   while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
678     icu::UnicodeString token(unicode_text, last_break_index,
679                              break_index - last_break_index);
680     int token_length = token.countChar32();
681     unicode_index = last_unicode_index + token_length;
682 
683     std::string token_utf8;
684     token.toUTF8String(token_utf8);
685 
686     bool is_whitespace = true;
687     for (int i = 0; i < token.length(); i++) {
688       if (!u_isWhitespace(token.char32At(i))) {
689         is_whitespace = false;
690       }
691     }
692 
693     if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
694       result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
695     }
696 
697     last_break_index = break_index;
698     last_unicode_index = unicode_index;
699   }
700 
701   return true;
702 }
703 
InternalRetokenize(const std::string & context,std::vector<Token> * tokens) const704 void FeatureProcessor::InternalRetokenize(const std::string& context,
705                                           std::vector<Token>* tokens) const {
706   const UnicodeText unicode_text =
707       UTF8ToUnicodeText(context, /*do_copy=*/false);
708 
709   std::vector<Token> result;
710   CodepointSpan span(-1, -1);
711   for (Token& token : *tokens) {
712     const UnicodeText unicode_token_value =
713         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
714     bool should_retokenize = true;
715     for (const int codepoint : unicode_token_value) {
716       if (!IsCodepointInRanges(codepoint,
717                                internal_tokenizer_codepoint_ranges_)) {
718         should_retokenize = false;
719         break;
720       }
721     }
722 
723     if (should_retokenize) {
724       if (span.first < 0) {
725         span.first = token.start;
726       }
727       span.second = token.end;
728     } else {
729       TokenizeSubstring(unicode_text, span, &result);
730       span.first = -1;
731       result.emplace_back(std::move(token));
732     }
733   }
734   TokenizeSubstring(unicode_text, span, &result);
735 
736   *tokens = std::move(result);
737 }
738 
TokenizeSubstring(const UnicodeText & unicode_text,CodepointSpan span,std::vector<Token> * result) const739 void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
740                                          CodepointSpan span,
741                                          std::vector<Token>* result) const {
742   if (span.first < 0) {
743     // There is no span to tokenize.
744     return;
745   }
746 
747   // Extract the substring.
748   UnicodeText::const_iterator it_begin = unicode_text.begin();
749   for (int i = 0; i < span.first; ++i) {
750     ++it_begin;
751   }
752   UnicodeText::const_iterator it_end = unicode_text.begin();
753   for (int i = 0; i < span.second; ++i) {
754     ++it_end;
755   }
756   const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
757 
758   // Run the tokenizer and update the token bounds to reflect the offset of the
759   // substring.
760   std::vector<Token> tokens = tokenizer_.Tokenize(text);
761   for (Token& token : tokens) {
762     token.start += span.first;
763     token.end += span.first;
764     result->emplace_back(std::move(token));
765   }
766 }
767 
768 }  // namespace libtextclassifier
769