/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Feature processing for FFModel (feed-forward SmartSelection model).
18 
19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
20 #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
21 
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <vector>
27 
28 #include "annotator/cached-features.h"
29 #include "annotator/model_generated.h"
30 #include "annotator/types.h"
31 #include "utils/base/integral_types.h"
32 #include "utils/base/logging.h"
33 #include "utils/token-feature-extractor.h"
34 #include "utils/tokenizer.h"
35 #include "utils/utf8/unicodetext.h"
36 #include "utils/utf8/unilib.h"
37 
38 namespace libtextclassifier3 {
39 
// Sentinel value representing an invalid label.
constexpr int kInvalidLabel = -1;
41 
42 namespace internal {
43 
44 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
45                          const UniLib* unilib);
46 
47 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
48     const FeatureProcessorOptions* options);
49 
50 // Splits tokens that contain the selection boundary inside them.
51 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
52 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
53                                       std::vector<Token>* tokens);
54 
55 // Returns the index of token that corresponds to the codepoint span.
56 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
57 
58 // Returns the index of token that corresponds to the middle of the  codepoint
59 // span.
60 int CenterTokenFromMiddleOfSelection(
61     CodepointSpan span, const std::vector<Token>& selectable_tokens);
62 
63 // Strips the tokens from the tokens vector that are not used for feature
64 // extraction because they are out of scope, or pads them so that there is
65 // enough tokens in the required context_size for all inferences with a click
66 // in relative_click_span.
67 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
68                       std::vector<Token>* tokens, int* click_pos);
69 
70 }  // namespace internal
71 
72 // Converts a codepoint span to a token span in the given list of tokens.
73 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a
74 // token to overlap with the codepoint range to be considered part of it.
75 // Otherwise it must be fully included in the range.
76 TokenSpan CodepointSpanToTokenSpan(
77     const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
78     bool snap_boundaries_to_containing_tokens = false);
79 
80 // Converts a token span to a codepoint span in the given list of tokens.
81 CodepointSpan TokenSpanToCodepointSpan(
82     const std::vector<Token>& selectable_tokens, TokenSpan token_span);
83 
84 // Takes care of preparing features for the span prediction model.
85 class FeatureProcessor {
86  public:
87   // A cache mapping codepoint spans to embedded tokens features. An instance
88   // can be provided to multiple calls to ExtractFeatures() operating on the
89   // same context (the same codepoint spans corresponding to the same tokens),
90   // as an optimization. Note that the tokenizations do not have to be
91   // identical.
92   typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
93 
FeatureProcessor(const FeatureProcessorOptions * options,const UniLib * unilib)94   FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
95       : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
96                            *unilib),
97         options_(options),
98         tokenizer_(internal::BuildTokenizer(options, unilib)) {
99     MakeLabelMaps();
100     if (options->supported_codepoint_ranges() != nullptr) {
101       SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
102                            options->supported_codepoint_ranges()->end()},
103                           &supported_codepoint_ranges_);
104     }
105     PrepareIgnoredSpanBoundaryCodepoints();
106   }
107 
108   // Tokenizes the input string using the selected tokenization method.
109   std::vector<Token> Tokenize(const std::string& text) const;
110 
111   // Same as above but takes UnicodeText.
112   std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
113 
114   // Converts a label into a token span.
115   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
116 
117   // Gets the total number of selection labels.
GetSelectionLabelCount()118   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
119 
120   // Gets the string value for given collection label.
121   std::string LabelToCollection(int label) const;
122 
123   // Gets the total number of collections of the model.
NumCollections()124   int NumCollections() const { return collection_to_label_.size(); }
125 
126   // Gets the name of the default collection.
127   std::string GetDefaultCollection() const;
128 
GetOptions()129   const FeatureProcessorOptions* GetOptions() const { return options_; }
130 
131   // Retokenizes the context and input span, and finds the click position.
132   // Depending on the options, might modify tokens (split them or remove them).
133   void RetokenizeAndFindClick(const std::string& context,
134                               CodepointSpan input_span,
135                               bool only_use_line_with_click,
136                               std::vector<Token>* tokens, int* click_pos) const;
137 
138   // Same as above but takes UnicodeText.
139   void RetokenizeAndFindClick(const UnicodeText& context_unicode,
140                               CodepointSpan input_span,
141                               bool only_use_line_with_click,
142                               std::vector<Token>* tokens, int* click_pos) const;
143 
144   // Returns true if the token span has enough supported codepoints (as defined
145   // in the model config) or not and model should not run.
146   bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
147                                     TokenSpan token_span) const;
148 
149   // Extracts features as a CachedFeatures object that can be used for repeated
150   // inference over token spans in the given context.
151   bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
152                        CodepointSpan selection_span_for_feature,
153                        const EmbeddingExecutor* embedding_executor,
154                        EmbeddingCache* embedding_cache, int feature_vector_size,
155                        std::unique_ptr<CachedFeatures>* cached_features) const;
156 
157   // Fills selection_label_spans with CodepointSpans that correspond to the
158   // selection labels. The CodepointSpans are based on the codepoint ranges of
159   // given tokens.
160   bool SelectionLabelSpans(
161       VectorSpan<Token> tokens,
162       std::vector<CodepointSpan>* selection_label_spans) const;
163 
DenseFeaturesCount()164   int DenseFeaturesCount() const {
165     return feature_extractor_.DenseFeaturesCount();
166   }
167 
EmbeddingSize()168   int EmbeddingSize() const { return options_->embedding_size(); }
169 
170   // Splits context to several segments.
171   std::vector<UnicodeTextRange> SplitContext(
172       const UnicodeText& context_unicode) const;
173 
174   // Strips boundary codepoints from the span in context and returns the new
175   // start and end indices. If the span comprises entirely of boundary
176   // codepoints, the first index of span is returned for both indices.
177   CodepointSpan StripBoundaryCodepoints(const std::string& context,
178                                         CodepointSpan span) const;
179 
180   // Same as above but takes UnicodeText.
181   CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
182                                         CodepointSpan span) const;
183 
184   // Same as above but takes a pair of iterators for the span, for efficiency.
185   CodepointSpan StripBoundaryCodepoints(
186       const UnicodeText::const_iterator& span_begin,
187       const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
188 
189   // Same as above, but takes an optional buffer for saving the modified value.
190   // As an optimization, returns pointer to 'value' if nothing was stripped, or
191   // pointer to 'buffer' if something was stripped.
192   const std::string& StripBoundaryCodepoints(const std::string& value,
193                                              std::string* buffer) const;
194 
195  protected:
196   // Returns the class id corresponding to the given string collection
197   // identifier. There is a catch-all class id that the function returns for
198   // unknown collections.
199   int CollectionToLabel(const std::string& collection) const;
200 
201   // Prepares mapping from collection names to labels.
202   void MakeLabelMaps();
203 
204   // Gets the number of spannable tokens for the model.
205   //
206   // Spannable tokens are those tokens of context, which the model predicts
207   // selection spans over (i.e., there is 1:1 correspondence between the output
208   // classes of the model and each of the spannable tokens).
GetNumContextTokens()209   int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
210 
211   // Converts a label into a span of codepoint indices corresponding to it
212   // given output_tokens.
213   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
214                    CodepointSpan* span) const;
215 
216   // Converts a span to the corresponding label given output_tokens.
217   bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
218                    const std::vector<Token>& output_tokens, int* label) const;
219 
220   // Converts a token span to the corresponding label.
221   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
222 
223   // Returns the ratio of supported codepoints to total number of codepoints in
224   // the given token span.
225   float SupportedCodepointsRatio(const TokenSpan& token_span,
226                                  const std::vector<Token>& tokens) const;
227 
228   void PrepareIgnoredSpanBoundaryCodepoints();
229 
230   // Counts the number of span boundary codepoints. If count_from_beginning is
231   // True, the counting will start at the span_start iterator (inclusive) and at
232   // maximum end at span_end (exclusive). If count_from_beginning is True, the
233   // counting will start from span_end (exclusive) and end at span_start
234   // (inclusive).
235   int CountIgnoredSpanBoundaryCodepoints(
236       const UnicodeText::const_iterator& span_start,
237       const UnicodeText::const_iterator& span_end,
238       bool count_from_beginning) const;
239 
240   // Finds the center token index in tokens vector, using the method defined
241   // in options_.
242   int FindCenterToken(CodepointSpan span,
243                       const std::vector<Token>& tokens) const;
244 
245   // Removes all tokens from tokens that are not on a line (defined by calling
246   // SplitContext on the context) to which span points.
247   void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
248                                  std::vector<Token>* tokens) const;
249 
250   // Same as above but takes UnicodeText.
251   void StripTokensFromOtherLines(const UnicodeText& context_unicode,
252                                  CodepointSpan span,
253                                  std::vector<Token>* tokens) const;
254 
255   // Extracts the features of a token and appends them to the output vector.
256   // Uses the embedding cache to to avoid re-extracting the re-embedding the
257   // sparse features for the same token.
258   bool AppendTokenFeaturesWithCache(const Token& token,
259                                     CodepointSpan selection_span_for_feature,
260                                     const EmbeddingExecutor* embedding_executor,
261                                     EmbeddingCache* embedding_cache,
262                                     std::vector<float>* output_features) const;
263 
264  protected:
265   const TokenFeatureExtractor feature_extractor_;
266 
267   // Codepoint ranges that define what codepoints are supported by the model.
268   // NOTE: Must be sorted.
269   std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
270 
271  private:
272   // Set of codepoints that will be stripped from beginning and end of
273   // predicted spans.
274   std::set<int32> ignored_span_boundary_codepoints_;
275 
276   const FeatureProcessorOptions* const options_;
277 
278   // Mapping between token selection spans and labels ids.
279   std::map<TokenSpan, int> selection_to_label_;
280   std::vector<TokenSpan> label_to_selection_;
281 
282   // Mapping between collections and labels.
283   std::map<std::string, int> collection_to_label_;
284 
285   Tokenizer tokenizer_;
286 };
287 
288 }  // namespace libtextclassifier3
289 
290 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
291