1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Inference code for the text classification model.
18
19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
20 #define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
21
22 #include <memory>
23 #include <set>
24 #include <string>
25 #include <unordered_set>
26 #include <vector>
27
28 #include "annotator/contact/contact-engine.h"
29 #include "annotator/datetime/parser.h"
30 #include "annotator/duration/duration.h"
31 #include "annotator/feature-processor.h"
32 #include "annotator/installed_app/installed-app-engine.h"
33 #include "annotator/knowledge/knowledge-engine.h"
34 #include "annotator/model-executor.h"
35 #include "annotator/model_generated.h"
36 #include "annotator/number/number.h"
37 #include "annotator/strip-unpaired-brackets.h"
38 #include "annotator/types.h"
39 #include "annotator/zlib-utils.h"
40 #include "utils/flatbuffers.h"
41 #include "utils/i18n/locale.h"
42 #include "utils/memory/mmap.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/zlib/zlib.h"
45
46 namespace libtextclassifier3 {
47
48 // Aliases for long enum values.
49 const AnnotationUsecase ANNOTATION_USECASE_SMART =
50 AnnotationUsecase_ANNOTATION_USECASE_SMART;
51 const AnnotationUsecase ANNOTATION_USECASE_RAW =
52 AnnotationUsecase_ANNOTATION_USECASE_RAW;
53
54 struct SelectionOptions {
55 // Comma-separated list of locale specification for the input text (BCP 47
56 // tags).
57 std::string locales;
58
59 // Comma-separated list of BCP 47 language tags.
60 std::string detected_text_language_tags;
61
62 // Tailors the output annotations according to the specified use-case.
63 AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
64
65 bool operator==(const SelectionOptions& other) const {
66 return this->locales == other.locales &&
67 this->annotation_usecase == other.annotation_usecase &&
68 this->detected_text_language_tags ==
69 other.detected_text_language_tags;
70 }
71 };
72
73 struct ClassificationOptions {
74 // For parsing relative datetimes, the reference now time against which the
75 // relative datetimes get resolved.
76 // UTC milliseconds since epoch.
77 int64 reference_time_ms_utc = 0;
78
79 // Timezone in which the input text was written (format as accepted by ICU).
80 std::string reference_timezone;
81
82 // Comma-separated list of locale specification for the input text (BCP 47
83 // tags).
84 std::string locales;
85
86 // Comma-separated list of language tags.
87 std::string detected_text_language_tags;
88
89 // Tailors the output annotations according to the specified use-case.
90 AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
91
92 bool operator==(const ClassificationOptions& other) const {
93 return this->reference_time_ms_utc == other.reference_time_ms_utc &&
94 this->reference_timezone == other.reference_timezone &&
95 this->locales == other.locales &&
96 this->detected_text_language_tags ==
97 other.detected_text_language_tags &&
98 this->annotation_usecase == other.annotation_usecase;
99 }
100 };
101
102 struct AnnotationOptions {
103 // For parsing relative datetimes, the reference now time against which the
104 // relative datetimes get resolved.
105 // UTC milliseconds since epoch.
106 int64 reference_time_ms_utc = 0;
107
108 // Timezone in which the input text was written (format as accepted by ICU).
109 std::string reference_timezone;
110
111 // Comma-separated list of locale specification for the input text (BCP 47
112 // tags).
113 std::string locales;
114
115 // Comma-separated list of language tags.
116 std::string detected_text_language_tags;
117
118 // List of entity types that should be used for annotation.
119 std::unordered_set<std::string> entity_types;
120
121 // If true, serialized_entity_data in the results is populated."
122 bool is_serialized_entity_data_enabled = false;
123
124 // Tailors the output annotations according to the specified use-case.
125 AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
126
127 bool operator==(const AnnotationOptions& other) const {
128 return this->reference_time_ms_utc == other.reference_time_ms_utc &&
129 this->reference_timezone == other.reference_timezone &&
130 this->locales == other.locales &&
131 this->detected_text_language_tags ==
132 other.detected_text_language_tags &&
133 this->annotation_usecase == other.annotation_usecase &&
134 this->is_serialized_entity_data_enabled ==
135 other.is_serialized_entity_data_enabled;
136 }
137 };
138
139 // Holds TFLite interpreters for selection and classification models.
140 // NOTE: his class is not thread-safe, thus should NOT be re-used across
141 // threads.
142 class InterpreterManager {
143 public:
144 // The constructor can be called with nullptr for any of the executors, and is
145 // a defined behavior, as long as the corresponding *Interpreter() method is
146 // not called when the executor is null.
InterpreterManager(const ModelExecutor * selection_executor,const ModelExecutor * classification_executor)147 InterpreterManager(const ModelExecutor* selection_executor,
148 const ModelExecutor* classification_executor)
149 : selection_executor_(selection_executor),
150 classification_executor_(classification_executor) {}
151
152 // Gets or creates and caches an interpreter for the selection model.
153 tflite::Interpreter* SelectionInterpreter();
154
155 // Gets or creates and caches an interpreter for the classification model.
156 tflite::Interpreter* ClassificationInterpreter();
157
158 private:
159 const ModelExecutor* selection_executor_;
160 const ModelExecutor* classification_executor_;
161
162 std::unique_ptr<tflite::Interpreter> selection_interpreter_;
163 std::unique_ptr<tflite::Interpreter> classification_interpreter_;
164 };
165
// Predicate over entity type names, backed by a caller-provided set of enabled
// entity types. An empty set means "no restriction": every entity type is
// treated as enabled.
// NOTE: Only a reference to the set is held, so the set has to outlive this
// object.
class EnabledEntityTypes {
 public:
  explicit EnabledEntityTypes(
      const std::unordered_set<std::string>& entity_types)
      : entity_types_(entity_types) {}

  // Returns true if 'entity_type' is enabled for annotation.
  bool operator()(const std::string& entity_type) const {
    if (entity_types_.empty()) {
      return true;
    }
    return entity_types_.count(entity_type) > 0;
  }

 private:
  const std::unordered_set<std::string>& entity_types_;
};
182
183 // A text processing model that provides text classification, annotation,
184 // selection suggestion for various types.
185 // NOTE: This class is not thread-safe.
186 class Annotator {
187 public:
188 static std::unique_ptr<Annotator> FromUnownedBuffer(
189 const char* buffer, int size, const UniLib* unilib = nullptr,
190 const CalendarLib* calendarlib = nullptr);
191 // Takes ownership of the mmap.
192 static std::unique_ptr<Annotator> FromScopedMmap(
193 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
194 const CalendarLib* calendarlib = nullptr);
195 static std::unique_ptr<Annotator> FromScopedMmap(
196 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
197 std::unique_ptr<CalendarLib> calendarlib);
198 static std::unique_ptr<Annotator> FromFileDescriptor(
199 int fd, int offset, int size, const UniLib* unilib = nullptr,
200 const CalendarLib* calendarlib = nullptr);
201 static std::unique_ptr<Annotator> FromFileDescriptor(
202 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
203 std::unique_ptr<CalendarLib> calendarlib);
204 static std::unique_ptr<Annotator> FromFileDescriptor(
205 int fd, const UniLib* unilib = nullptr,
206 const CalendarLib* calendarlib = nullptr);
207 static std::unique_ptr<Annotator> FromFileDescriptor(
208 int fd, std::unique_ptr<UniLib> unilib,
209 std::unique_ptr<CalendarLib> calendarlib);
210 static std::unique_ptr<Annotator> FromPath(
211 const std::string& path, const UniLib* unilib = nullptr,
212 const CalendarLib* calendarlib = nullptr);
213 static std::unique_ptr<Annotator> FromPath(
214 const std::string& path, std::unique_ptr<UniLib> unilib,
215 std::unique_ptr<CalendarLib> calendarlib);
216
217 // Returns true if the model is ready for use.
IsInitialized()218 bool IsInitialized() { return initialized_; }
219
220 // Initializes the knowledge engine with the given config.
221 bool InitializeKnowledgeEngine(const std::string& serialized_config);
222
223 // Initializes the contact engine with the given config.
224 bool InitializeContactEngine(const std::string& serialized_config);
225
226 // Initializes the installed app engine with the given config.
227 bool InitializeInstalledAppEngine(const std::string& serialized_config);
228
229 // Runs inference for given a context and current selection (i.e. index
230 // of the first and one past last selected characters (utf8 codepoint
231 // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
232 // beginning character and one past selection end character.
233 // Returns the original click_indices if an error occurs.
234 // NOTE: The selection indices are passed in and returned in terms of
235 // UTF8 codepoints (not bytes).
236 // Requires that the model is a smart selection model.
237 CodepointSpan SuggestSelection(
238 const std::string& context, CodepointSpan click_indices,
239 const SelectionOptions& options = SelectionOptions()) const;
240
241 // Classifies the selected text given the context string.
242 // Returns an empty result if an error occurs.
243 std::vector<ClassificationResult> ClassifyText(
244 const std::string& context, CodepointSpan selection_indices,
245 const ClassificationOptions& options = ClassificationOptions()) const;
246
247 // Annotates given input text. The annotations are sorted by their position
248 // in the context string and exclude spans classified as 'other'.
249 std::vector<AnnotatedSpan> Annotate(
250 const std::string& context,
251 const AnnotationOptions& options = AnnotationOptions()) const;
252
253 // Looks up a knowledge entity by its id. If successful, populates the
254 // serialized knowledge result and returns true.
255 bool LookUpKnowledgeEntity(const std::string& id,
256 std::string* serialized_knowledge_result) const;
257
258 const Model* model() const;
259 const reflection::Schema* entity_data_schema() const;
260
261 // Exposes the feature processor for tests and evaluations.
262 const FeatureProcessor* SelectionFeatureProcessorForTests() const;
263 const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
264
265 // Exposes the date time parser for tests and evaluations.
266 const DatetimeParser* DatetimeParserForTests() const;
267
268 static const std::string& kPhoneCollection;
269 static const std::string& kAddressCollection;
270 static const std::string& kDateCollection;
271 static const std::string& kUrlCollection;
272 static const std::string& kEmailCollection;
273
274 protected:
275 struct ScoredChunk {
276 TokenSpan token_span;
277 float score;
278 };
279
280 // Constructs and initializes text classifier from given model.
281 // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
282 Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
283 const UniLib* unilib, const CalendarLib* calendarlib);
284 Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
285 std::unique_ptr<UniLib> unilib,
286 std::unique_ptr<CalendarLib> calendarlib);
287
288 // Constructs, validates and initializes text classifier from given model.
289 // Does not own the buffer that backs 'model'.
290 Annotator(const Model* model, const UniLib* unilib,
291 const CalendarLib* calendarlib);
292
293 // Checks that model contains all required fields, and initializes internal
294 // datastructures.
295 void ValidateAndInitialize();
296
297 // Initializes regular expressions for the regex model.
298 bool InitializeRegexModel(ZlibDecompressor* decompressor);
299
300 // Resolves conflicts in the list of candidates by removing some overlapping
301 // ones. Returns indices of the surviving ones.
302 // NOTE: Assumes that the candidates are sorted according to their position in
303 // the span.
304 bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
305 const std::string& context,
306 const std::vector<Token>& cached_tokens,
307 const std::vector<Locale>& detected_text_language_tags,
308 AnnotationUsecase annotation_usecase,
309 InterpreterManager* interpreter_manager,
310 std::vector<int>* result) const;
311
312 // Resolves one conflict between candidates on indices 'start_index'
313 // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
314 // indices to 'chosen_indices'. Returns false if a problem arises.
315 bool ResolveConflict(const std::string& context,
316 const std::vector<Token>& cached_tokens,
317 const std::vector<AnnotatedSpan>& candidates,
318 const std::vector<Locale>& detected_text_language_tags,
319 int start_index, int end_index,
320 AnnotationUsecase annotation_usecase,
321 InterpreterManager* interpreter_manager,
322 std::vector<int>* chosen_indices) const;
323
324 // Gets selection candidates from the ML model.
325 // Provides the tokens produced during tokenization of the context string for
326 // reuse.
327 bool ModelSuggestSelection(
328 const UnicodeText& context_unicode, CodepointSpan click_indices,
329 const std::vector<Locale>& detected_text_language_tags,
330 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
331 std::vector<AnnotatedSpan>* result) const;
332
333 // Classifies the selected text given the context string with the
334 // classification model.
335 // Returns true if no error occurred.
336 bool ModelClassifyText(
337 const std::string& context, const std::vector<Token>& cached_tokens,
338 const std::vector<Locale>& locales, CodepointSpan selection_indices,
339 InterpreterManager* interpreter_manager,
340 FeatureProcessor::EmbeddingCache* embedding_cache,
341 std::vector<ClassificationResult>* classification_results,
342 std::vector<Token>* tokens) const;
343
344 // Same as above but doesn't output tokens.
345 bool ModelClassifyText(
346 const std::string& context, const std::vector<Token>& cached_tokens,
347 const std::vector<Locale>& detected_text_language_tags,
348 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
349 FeatureProcessor::EmbeddingCache* embedding_cache,
350 std::vector<ClassificationResult>* classification_results) const;
351
352 // Same as above but doesn't take cached tokens and doesn't output tokens.
353 bool ModelClassifyText(
354 const std::string& context,
355 const std::vector<Locale>& detected_text_language_tags,
356 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
357 FeatureProcessor::EmbeddingCache* embedding_cache,
358 std::vector<ClassificationResult>* classification_results) const;
359
360 // Returns a relative token span that represents how many tokens on the left
361 // from the selection and right from the selection are needed for the
362 // classifier input.
363 TokenSpan ClassifyTextUpperBoundNeededTokens() const;
364
365 // Classifies the selected text with the regular expressions models.
366 // Returns true if no error happened, false otherwise.
367 bool RegexClassifyText(
368 const std::string& context, CodepointSpan selection_indices,
369 std::vector<ClassificationResult>* classification_result) const;
370
371 // Classifies the selected text with the date time model.
372 // Returns true if no error happened, false otherwise.
373 bool DatetimeClassifyText(
374 const std::string& context, CodepointSpan selection_indices,
375 const ClassificationOptions& options,
376 std::vector<ClassificationResult>* classification_results) const;
377
378 // Chunks given input text with the selection model and classifies the spans
379 // with the classification model.
380 // The annotations are sorted by their position in the context string and
381 // exclude spans classified as 'other'.
382 // Provides the tokens produced during tokenization of the context string for
383 // reuse.
384 bool ModelAnnotate(const std::string& context,
385 const std::vector<Locale>& detected_text_language_tags,
386 InterpreterManager* interpreter_manager,
387 std::vector<Token>* tokens,
388 std::vector<AnnotatedSpan>* result) const;
389
390 // Groups the tokens into chunks. A chunk is a token span that should be the
391 // suggested selection when any of its contained tokens is clicked. The chunks
392 // are non-overlapping and are sorted by their position in the context string.
393 // "num_tokens" is the total number of tokens available (as this method does
394 // not need the actual vector of tokens).
395 // "span_of_interest" is a span of all the tokens that could be clicked.
396 // The resulting chunks all have to overlap with it and they cover this span
397 // completely. The first and last chunk might extend beyond it.
398 // The chunks vector is cleared before filling.
399 bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
400 tflite::Interpreter* selection_interpreter,
401 const CachedFeatures& cached_features,
402 std::vector<TokenSpan>* chunks) const;
403
404 // A helper method for ModelChunk(). It generates scored chunk candidates for
405 // a click context model.
406 // NOTE: The returned chunks can (and most likely do) overlap.
407 bool ModelClickContextScoreChunks(
408 int num_tokens, const TokenSpan& span_of_interest,
409 const CachedFeatures& cached_features,
410 tflite::Interpreter* selection_interpreter,
411 std::vector<ScoredChunk>* scored_chunks) const;
412
413 // A helper method for ModelChunk(). It generates scored chunk candidates for
414 // a bounds-sensitive model.
415 // NOTE: The returned chunks can (and most likely do) overlap.
416 bool ModelBoundsSensitiveScoreChunks(
417 int num_tokens, const TokenSpan& span_of_interest,
418 const TokenSpan& inference_span, const CachedFeatures& cached_features,
419 tflite::Interpreter* selection_interpreter,
420 std::vector<ScoredChunk>* scored_chunks) const;
421
422 // Produces chunks isolated by a set of regular expressions.
423 bool RegexChunk(const UnicodeText& context_unicode,
424 const std::vector<int>& rules,
425 std::vector<AnnotatedSpan>* result,
426 bool is_serialized_entity_data_enabled) const;
427
428 // Produces chunks from the datetime parser.
429 bool DatetimeChunk(const UnicodeText& context_unicode,
430 int64 reference_time_ms_utc,
431 const std::string& reference_timezone,
432 const std::string& locales, ModeFlag mode,
433 AnnotationUsecase annotation_usecase,
434 bool is_serialized_entity_data_enabled,
435 std::vector<AnnotatedSpan>* result) const;
436
437 // Returns whether a classification should be filtered.
438 bool FilteredForAnnotation(const AnnotatedSpan& span) const;
439 bool FilteredForClassification(
440 const ClassificationResult& classification) const;
441 bool FilteredForSelection(const AnnotatedSpan& span) const;
442
443 // Computes the selection boundaries from a regular expression match.
444 CodepointSpan ComputeSelectionBoundaries(
445 const UniLib::RegexMatcher* match,
446 const RegexModel_::Pattern* config) const;
447
448 // Returns whether a regex pattern provides entity data from a match.
449 bool HasEntityData(const RegexModel_::Pattern* pattern) const;
450
451 // Constructs and serializes entity data from regex matches.
452 bool SerializedEntityDataFromRegexMatch(
453 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
454 std::string* serialized_entity_data) const;
455
456 // Verifies a regex match and returns true if verification was successful.
457 bool VerifyRegexMatchCandidate(
458 const std::string& context,
459 const VerificationOptions* verification_options, const std::string& match,
460 const UniLib::RegexMatcher* matcher) const;
461
462 const Model* model_;
463
464 std::unique_ptr<const ModelExecutor> selection_executor_;
465 std::unique_ptr<const ModelExecutor> classification_executor_;
466 std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
467
468 std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
469 std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
470
471 std::unique_ptr<const DatetimeParser> datetime_parser_;
472
473 private:
474 struct CompiledRegexPattern {
475 const RegexModel_::Pattern* config;
476 std::unique_ptr<UniLib::RegexPattern> pattern;
477 };
478
479 // Removes annotations the entity type of which is not in the set of enabled
480 // entity types.
481 void RemoveNotEnabledEntityTypes(
482 const EnabledEntityTypes& is_entity_type_enabled,
483 std::vector<AnnotatedSpan>* annotated_spans) const;
484
485 std::unique_ptr<ScopedMmap> mmap_;
486 bool initialized_ = false;
487 bool enabled_for_annotation_ = false;
488 bool enabled_for_classification_ = false;
489 bool enabled_for_selection_ = false;
490 std::unordered_set<std::string> filtered_collections_annotation_;
491 std::unordered_set<std::string> filtered_collections_classification_;
492 std::unordered_set<std::string> filtered_collections_selection_;
493
494 std::vector<CompiledRegexPattern> regex_patterns_;
495
496 // Indices into regex_patterns_ for the different modes.
497 std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
498 selection_regex_patterns_;
499
500 std::unique_ptr<UniLib> owned_unilib_;
501 const UniLib* unilib_;
502 std::unique_ptr<CalendarLib> owned_calendarlib_;
503 const CalendarLib* calendarlib_;
504
505 std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
506 std::unique_ptr<const ContactEngine> contact_engine_;
507 std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
508 std::unique_ptr<const NumberAnnotator> number_annotator_;
509 std::unique_ptr<const DurationAnnotator> duration_annotator_;
510
511 // Builder for creating extra data.
512 const reflection::Schema* entity_data_schema_;
513 std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
514
515 // Locales for which the entire model triggers.
516 std::vector<Locale> model_triggering_locales_;
517
518 // Locales for which the ML model triggers.
519 std::vector<Locale> ml_model_triggering_locales_;
520
521 // Locales that the dictionary classification support.
522 std::vector<Locale> dictionary_locales_;
523 };
524
525 namespace internal {
526
527 // Helper function, which if the initial 'span' contains only white-spaces,
528 // moves the selection to a single-codepoint selection on the left side
529 // of this block of white-space.
530 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
531 const UnicodeText& context_unicode,
532 const UniLib& unilib);
533
534 // Copies tokens from 'cached_tokens' that are
535 // 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
536 // from the tokens that correspond to 'selection_indices'.
537 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
538 CodepointSpan selection_indices,
539 TokenSpan tokens_around_selection_to_copy);
540 } // namespace internal
541
542 // Interprets the buffer as a Model flatbuffer and returns it for reading.
543 const Model* ViewModel(const void* buffer, int size);
544
545 // Opens model from given path and runs a function, passing the loaded Model
546 // flatbuffer as an argument.
547 //
548 // This is mainly useful if we don't want to pay the cost for the model
549 // initialization because we'll be only reading some flatbuffer values from the
550 // file.
551 template <typename ReturnType, typename Func>
VisitAnnotatorModel(const std::string & path,Func function)552 ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
553 ScopedMmap mmap(path);
554 if (!mmap.handle().ok()) {
555 function(/*model=*/nullptr);
556 }
557 const Model* model =
558 ViewModel(mmap.handle().start(), mmap.handle().num_bytes());
559 return function(model);
560 }
561
562 } // namespace libtextclassifier3
563
564 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
565