1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Inference code for the text classification model.
18 
19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
20 #define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
21 
22 #include <memory>
23 #include <set>
24 #include <string>
25 #include <unordered_set>
26 #include <vector>
27 
28 #include "annotator/contact/contact-engine.h"
29 #include "annotator/datetime/parser.h"
30 #include "annotator/duration/duration.h"
31 #include "annotator/feature-processor.h"
32 #include "annotator/installed_app/installed-app-engine.h"
33 #include "annotator/knowledge/knowledge-engine.h"
34 #include "annotator/model-executor.h"
35 #include "annotator/model_generated.h"
36 #include "annotator/number/number.h"
37 #include "annotator/strip-unpaired-brackets.h"
38 #include "annotator/types.h"
39 #include "annotator/zlib-utils.h"
40 #include "utils/flatbuffers.h"
41 #include "utils/i18n/locale.h"
42 #include "utils/memory/mmap.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/zlib/zlib.h"
45 
46 namespace libtextclassifier3 {
47 
48 // Aliases for long enum values.
49 const AnnotationUsecase ANNOTATION_USECASE_SMART =
50     AnnotationUsecase_ANNOTATION_USECASE_SMART;
51 const AnnotationUsecase ANNOTATION_USECASE_RAW =
52     AnnotationUsecase_ANNOTATION_USECASE_RAW;
53 
54 struct SelectionOptions {
55   // Comma-separated list of locale specification for the input text (BCP 47
56   // tags).
57   std::string locales;
58 
59   // Comma-separated list of BCP 47 language tags.
60   std::string detected_text_language_tags;
61 
62   // Tailors the output annotations according to the specified use-case.
63   AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
64 
65   bool operator==(const SelectionOptions& other) const {
66     return this->locales == other.locales &&
67            this->annotation_usecase == other.annotation_usecase &&
68            this->detected_text_language_tags ==
69                other.detected_text_language_tags;
70   }
71 };
72 
73 struct ClassificationOptions {
74   // For parsing relative datetimes, the reference now time against which the
75   // relative datetimes get resolved.
76   // UTC milliseconds since epoch.
77   int64 reference_time_ms_utc = 0;
78 
79   // Timezone in which the input text was written (format as accepted by ICU).
80   std::string reference_timezone;
81 
82   // Comma-separated list of locale specification for the input text (BCP 47
83   // tags).
84   std::string locales;
85 
86   // Comma-separated list of language tags.
87   std::string detected_text_language_tags;
88 
89   // Tailors the output annotations according to the specified use-case.
90   AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
91 
92   bool operator==(const ClassificationOptions& other) const {
93     return this->reference_time_ms_utc == other.reference_time_ms_utc &&
94            this->reference_timezone == other.reference_timezone &&
95            this->locales == other.locales &&
96            this->detected_text_language_tags ==
97                other.detected_text_language_tags &&
98            this->annotation_usecase == other.annotation_usecase;
99   }
100 };
101 
102 struct AnnotationOptions {
103   // For parsing relative datetimes, the reference now time against which the
104   // relative datetimes get resolved.
105   // UTC milliseconds since epoch.
106   int64 reference_time_ms_utc = 0;
107 
108   // Timezone in which the input text was written (format as accepted by ICU).
109   std::string reference_timezone;
110 
111   // Comma-separated list of locale specification for the input text (BCP 47
112   // tags).
113   std::string locales;
114 
115   // Comma-separated list of language tags.
116   std::string detected_text_language_tags;
117 
118   // List of entity types that should be used for annotation.
119   std::unordered_set<std::string> entity_types;
120 
121   // If true, serialized_entity_data in the results is populated."
122   bool is_serialized_entity_data_enabled = false;
123 
124   // Tailors the output annotations according to the specified use-case.
125   AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
126 
127   bool operator==(const AnnotationOptions& other) const {
128     return this->reference_time_ms_utc == other.reference_time_ms_utc &&
129            this->reference_timezone == other.reference_timezone &&
130            this->locales == other.locales &&
131            this->detected_text_language_tags ==
132                other.detected_text_language_tags &&
133            this->annotation_usecase == other.annotation_usecase &&
134            this->is_serialized_entity_data_enabled ==
135                other.is_serialized_entity_data_enabled;
136   }
137 };
138 
139 // Holds TFLite interpreters for selection and classification models.
140 // NOTE: his class is not thread-safe, thus should NOT be re-used across
141 // threads.
142 class InterpreterManager {
143  public:
144   // The constructor can be called with nullptr for any of the executors, and is
145   // a defined behavior, as long as the corresponding *Interpreter() method is
146   // not called when the executor is null.
InterpreterManager(const ModelExecutor * selection_executor,const ModelExecutor * classification_executor)147   InterpreterManager(const ModelExecutor* selection_executor,
148                      const ModelExecutor* classification_executor)
149       : selection_executor_(selection_executor),
150         classification_executor_(classification_executor) {}
151 
152   // Gets or creates and caches an interpreter for the selection model.
153   tflite::Interpreter* SelectionInterpreter();
154 
155   // Gets or creates and caches an interpreter for the classification model.
156   tflite::Interpreter* ClassificationInterpreter();
157 
158  private:
159   const ModelExecutor* selection_executor_;
160   const ModelExecutor* classification_executor_;
161 
162   std::unique_ptr<tflite::Interpreter> selection_interpreter_;
163   std::unique_ptr<tflite::Interpreter> classification_interpreter_;
164 };
165 
// Predicate over entity type names. operator() reports whether a given entity
// type is enabled for annotation. An empty set enables every entity type.
//
// NOTE: Only a reference to the set is stored, so the set must outlive this
// object. Constructing from a temporary set is rejected at compile time,
// because the reference would dangle as soon as the temporary was destroyed.
class EnabledEntityTypes {
 public:
  explicit EnabledEntityTypes(
      const std::unordered_set<std::string>& entity_types)
      : entity_types_(entity_types) {}

  // Binding 'entity_types_' to a temporary would leave it dangling once the
  // full-expression ends, so forbid it.
  explicit EnabledEntityTypes(std::unordered_set<std::string>&&) = delete;

  // Returns true if 'entity_type' is enabled: either no restriction was given
  // (empty set), or the type is explicitly listed.
  bool operator()(const std::string& entity_type) const {
    return entity_types_.empty() ||
           entity_types_.find(entity_type) != entity_types_.cend();
  }

 private:
  const std::unordered_set<std::string>& entity_types_;
};
182 
183 // A text processing model that provides text classification, annotation,
184 // selection suggestion for various types.
185 // NOTE: This class is not thread-safe.
186 class Annotator {
187  public:
188   static std::unique_ptr<Annotator> FromUnownedBuffer(
189       const char* buffer, int size, const UniLib* unilib = nullptr,
190       const CalendarLib* calendarlib = nullptr);
191   // Takes ownership of the mmap.
192   static std::unique_ptr<Annotator> FromScopedMmap(
193       std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
194       const CalendarLib* calendarlib = nullptr);
195   static std::unique_ptr<Annotator> FromScopedMmap(
196       std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
197       std::unique_ptr<CalendarLib> calendarlib);
198   static std::unique_ptr<Annotator> FromFileDescriptor(
199       int fd, int offset, int size, const UniLib* unilib = nullptr,
200       const CalendarLib* calendarlib = nullptr);
201   static std::unique_ptr<Annotator> FromFileDescriptor(
202       int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
203       std::unique_ptr<CalendarLib> calendarlib);
204   static std::unique_ptr<Annotator> FromFileDescriptor(
205       int fd, const UniLib* unilib = nullptr,
206       const CalendarLib* calendarlib = nullptr);
207   static std::unique_ptr<Annotator> FromFileDescriptor(
208       int fd, std::unique_ptr<UniLib> unilib,
209       std::unique_ptr<CalendarLib> calendarlib);
210   static std::unique_ptr<Annotator> FromPath(
211       const std::string& path, const UniLib* unilib = nullptr,
212       const CalendarLib* calendarlib = nullptr);
213   static std::unique_ptr<Annotator> FromPath(
214       const std::string& path, std::unique_ptr<UniLib> unilib,
215       std::unique_ptr<CalendarLib> calendarlib);
216 
217   // Returns true if the model is ready for use.
IsInitialized()218   bool IsInitialized() { return initialized_; }
219 
220   // Initializes the knowledge engine with the given config.
221   bool InitializeKnowledgeEngine(const std::string& serialized_config);
222 
223   // Initializes the contact engine with the given config.
224   bool InitializeContactEngine(const std::string& serialized_config);
225 
226   // Initializes the installed app engine with the given config.
227   bool InitializeInstalledAppEngine(const std::string& serialized_config);
228 
229   // Runs inference for given a context and current selection (i.e. index
230   // of the first and one past last selected characters (utf8 codepoint
231   // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
232   // beginning character and one past selection end character.
233   // Returns the original click_indices if an error occurs.
234   // NOTE: The selection indices are passed in and returned in terms of
235   // UTF8 codepoints (not bytes).
236   // Requires that the model is a smart selection model.
237   CodepointSpan SuggestSelection(
238       const std::string& context, CodepointSpan click_indices,
239       const SelectionOptions& options = SelectionOptions()) const;
240 
241   // Classifies the selected text given the context string.
242   // Returns an empty result if an error occurs.
243   std::vector<ClassificationResult> ClassifyText(
244       const std::string& context, CodepointSpan selection_indices,
245       const ClassificationOptions& options = ClassificationOptions()) const;
246 
247   // Annotates given input text. The annotations are sorted by their position
248   // in the context string and exclude spans classified as 'other'.
249   std::vector<AnnotatedSpan> Annotate(
250       const std::string& context,
251       const AnnotationOptions& options = AnnotationOptions()) const;
252 
253   // Looks up a knowledge entity by its id. If successful, populates the
254   // serialized knowledge result and returns true.
255   bool LookUpKnowledgeEntity(const std::string& id,
256                              std::string* serialized_knowledge_result) const;
257 
258   const Model* model() const;
259   const reflection::Schema* entity_data_schema() const;
260 
261   // Exposes the feature processor for tests and evaluations.
262   const FeatureProcessor* SelectionFeatureProcessorForTests() const;
263   const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
264 
265   // Exposes the date time parser for tests and evaluations.
266   const DatetimeParser* DatetimeParserForTests() const;
267 
268   static const std::string& kPhoneCollection;
269   static const std::string& kAddressCollection;
270   static const std::string& kDateCollection;
271   static const std::string& kUrlCollection;
272   static const std::string& kEmailCollection;
273 
274  protected:
275   struct ScoredChunk {
276     TokenSpan token_span;
277     float score;
278   };
279 
280   // Constructs and initializes text classifier from given model.
281   // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
282   Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
283             const UniLib* unilib, const CalendarLib* calendarlib);
284   Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
285             std::unique_ptr<UniLib> unilib,
286             std::unique_ptr<CalendarLib> calendarlib);
287 
288   // Constructs, validates and initializes text classifier from given model.
289   // Does not own the buffer that backs 'model'.
290   Annotator(const Model* model, const UniLib* unilib,
291             const CalendarLib* calendarlib);
292 
293   // Checks that model contains all required fields, and initializes internal
294   // datastructures.
295   void ValidateAndInitialize();
296 
297   // Initializes regular expressions for the regex model.
298   bool InitializeRegexModel(ZlibDecompressor* decompressor);
299 
300   // Resolves conflicts in the list of candidates by removing some overlapping
301   // ones. Returns indices of the surviving ones.
302   // NOTE: Assumes that the candidates are sorted according to their position in
303   // the span.
304   bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
305                         const std::string& context,
306                         const std::vector<Token>& cached_tokens,
307                         const std::vector<Locale>& detected_text_language_tags,
308                         AnnotationUsecase annotation_usecase,
309                         InterpreterManager* interpreter_manager,
310                         std::vector<int>* result) const;
311 
312   // Resolves one conflict between candidates on indices 'start_index'
313   // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
314   // indices to 'chosen_indices'. Returns false if a problem arises.
315   bool ResolveConflict(const std::string& context,
316                        const std::vector<Token>& cached_tokens,
317                        const std::vector<AnnotatedSpan>& candidates,
318                        const std::vector<Locale>& detected_text_language_tags,
319                        int start_index, int end_index,
320                        AnnotationUsecase annotation_usecase,
321                        InterpreterManager* interpreter_manager,
322                        std::vector<int>* chosen_indices) const;
323 
324   // Gets selection candidates from the ML model.
325   // Provides the tokens produced during tokenization of the context string for
326   // reuse.
327   bool ModelSuggestSelection(
328       const UnicodeText& context_unicode, CodepointSpan click_indices,
329       const std::vector<Locale>& detected_text_language_tags,
330       InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
331       std::vector<AnnotatedSpan>* result) const;
332 
333   // Classifies the selected text given the context string with the
334   // classification model.
335   // Returns true if no error occurred.
336   bool ModelClassifyText(
337       const std::string& context, const std::vector<Token>& cached_tokens,
338       const std::vector<Locale>& locales, CodepointSpan selection_indices,
339       InterpreterManager* interpreter_manager,
340       FeatureProcessor::EmbeddingCache* embedding_cache,
341       std::vector<ClassificationResult>* classification_results,
342       std::vector<Token>* tokens) const;
343 
344   // Same as above but doesn't output tokens.
345   bool ModelClassifyText(
346       const std::string& context, const std::vector<Token>& cached_tokens,
347       const std::vector<Locale>& detected_text_language_tags,
348       CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
349       FeatureProcessor::EmbeddingCache* embedding_cache,
350       std::vector<ClassificationResult>* classification_results) const;
351 
352   // Same as above but doesn't take cached tokens and doesn't output tokens.
353   bool ModelClassifyText(
354       const std::string& context,
355       const std::vector<Locale>& detected_text_language_tags,
356       CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
357       FeatureProcessor::EmbeddingCache* embedding_cache,
358       std::vector<ClassificationResult>* classification_results) const;
359 
360   // Returns a relative token span that represents how many tokens on the left
361   // from the selection and right from the selection are needed for the
362   // classifier input.
363   TokenSpan ClassifyTextUpperBoundNeededTokens() const;
364 
365   // Classifies the selected text with the regular expressions models.
366   // Returns true if no error happened, false otherwise.
367   bool RegexClassifyText(
368       const std::string& context, CodepointSpan selection_indices,
369       std::vector<ClassificationResult>* classification_result) const;
370 
371   // Classifies the selected text with the date time model.
372   // Returns true if no error happened, false otherwise.
373   bool DatetimeClassifyText(
374       const std::string& context, CodepointSpan selection_indices,
375       const ClassificationOptions& options,
376       std::vector<ClassificationResult>* classification_results) const;
377 
378   // Chunks given input text with the selection model and classifies the spans
379   // with the classification model.
380   // The annotations are sorted by their position in the context string and
381   // exclude spans classified as 'other'.
382   // Provides the tokens produced during tokenization of the context string for
383   // reuse.
384   bool ModelAnnotate(const std::string& context,
385                      const std::vector<Locale>& detected_text_language_tags,
386                      InterpreterManager* interpreter_manager,
387                      std::vector<Token>* tokens,
388                      std::vector<AnnotatedSpan>* result) const;
389 
390   // Groups the tokens into chunks. A chunk is a token span that should be the
391   // suggested selection when any of its contained tokens is clicked. The chunks
392   // are non-overlapping and are sorted by their position in the context string.
393   // "num_tokens" is the total number of tokens available (as this method does
394   // not need the actual vector of tokens).
395   // "span_of_interest" is a span of all the tokens that could be clicked.
396   // The resulting chunks all have to overlap with it and they cover this span
397   // completely. The first and last chunk might extend beyond it.
398   // The chunks vector is cleared before filling.
399   bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
400                   tflite::Interpreter* selection_interpreter,
401                   const CachedFeatures& cached_features,
402                   std::vector<TokenSpan>* chunks) const;
403 
404   // A helper method for ModelChunk(). It generates scored chunk candidates for
405   // a click context model.
406   // NOTE: The returned chunks can (and most likely do) overlap.
407   bool ModelClickContextScoreChunks(
408       int num_tokens, const TokenSpan& span_of_interest,
409       const CachedFeatures& cached_features,
410       tflite::Interpreter* selection_interpreter,
411       std::vector<ScoredChunk>* scored_chunks) const;
412 
413   // A helper method for ModelChunk(). It generates scored chunk candidates for
414   // a bounds-sensitive model.
415   // NOTE: The returned chunks can (and most likely do) overlap.
416   bool ModelBoundsSensitiveScoreChunks(
417       int num_tokens, const TokenSpan& span_of_interest,
418       const TokenSpan& inference_span, const CachedFeatures& cached_features,
419       tflite::Interpreter* selection_interpreter,
420       std::vector<ScoredChunk>* scored_chunks) const;
421 
422   // Produces chunks isolated by a set of regular expressions.
423   bool RegexChunk(const UnicodeText& context_unicode,
424                   const std::vector<int>& rules,
425                   std::vector<AnnotatedSpan>* result,
426                   bool is_serialized_entity_data_enabled) const;
427 
428   // Produces chunks from the datetime parser.
429   bool DatetimeChunk(const UnicodeText& context_unicode,
430                      int64 reference_time_ms_utc,
431                      const std::string& reference_timezone,
432                      const std::string& locales, ModeFlag mode,
433                      AnnotationUsecase annotation_usecase,
434                      bool is_serialized_entity_data_enabled,
435                      std::vector<AnnotatedSpan>* result) const;
436 
437   // Returns whether a classification should be filtered.
438   bool FilteredForAnnotation(const AnnotatedSpan& span) const;
439   bool FilteredForClassification(
440       const ClassificationResult& classification) const;
441   bool FilteredForSelection(const AnnotatedSpan& span) const;
442 
443   // Computes the selection boundaries from a regular expression match.
444   CodepointSpan ComputeSelectionBoundaries(
445       const UniLib::RegexMatcher* match,
446       const RegexModel_::Pattern* config) const;
447 
448   // Returns whether a regex pattern provides entity data from a match.
449   bool HasEntityData(const RegexModel_::Pattern* pattern) const;
450 
451   // Constructs and serializes entity data from regex matches.
452   bool SerializedEntityDataFromRegexMatch(
453       const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
454       std::string* serialized_entity_data) const;
455 
456   // Verifies a regex match and returns true if verification was successful.
457   bool VerifyRegexMatchCandidate(
458       const std::string& context,
459       const VerificationOptions* verification_options, const std::string& match,
460       const UniLib::RegexMatcher* matcher) const;
461 
462   const Model* model_;
463 
464   std::unique_ptr<const ModelExecutor> selection_executor_;
465   std::unique_ptr<const ModelExecutor> classification_executor_;
466   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
467 
468   std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
469   std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
470 
471   std::unique_ptr<const DatetimeParser> datetime_parser_;
472 
473  private:
474   struct CompiledRegexPattern {
475     const RegexModel_::Pattern* config;
476     std::unique_ptr<UniLib::RegexPattern> pattern;
477   };
478 
479   // Removes annotations the entity type of which is not in the set of enabled
480   // entity types.
481   void RemoveNotEnabledEntityTypes(
482       const EnabledEntityTypes& is_entity_type_enabled,
483       std::vector<AnnotatedSpan>* annotated_spans) const;
484 
485   std::unique_ptr<ScopedMmap> mmap_;
486   bool initialized_ = false;
487   bool enabled_for_annotation_ = false;
488   bool enabled_for_classification_ = false;
489   bool enabled_for_selection_ = false;
490   std::unordered_set<std::string> filtered_collections_annotation_;
491   std::unordered_set<std::string> filtered_collections_classification_;
492   std::unordered_set<std::string> filtered_collections_selection_;
493 
494   std::vector<CompiledRegexPattern> regex_patterns_;
495 
496   // Indices into regex_patterns_ for the different modes.
497   std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
498       selection_regex_patterns_;
499 
500   std::unique_ptr<UniLib> owned_unilib_;
501   const UniLib* unilib_;
502   std::unique_ptr<CalendarLib> owned_calendarlib_;
503   const CalendarLib* calendarlib_;
504 
505   std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
506   std::unique_ptr<const ContactEngine> contact_engine_;
507   std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
508   std::unique_ptr<const NumberAnnotator> number_annotator_;
509   std::unique_ptr<const DurationAnnotator> duration_annotator_;
510 
511   // Builder for creating extra data.
512   const reflection::Schema* entity_data_schema_;
513   std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
514 
515   // Locales for which the entire model triggers.
516   std::vector<Locale> model_triggering_locales_;
517 
518   // Locales for which the ML model triggers.
519   std::vector<Locale> ml_model_triggering_locales_;
520 
521   // Locales that the dictionary classification support.
522   std::vector<Locale> dictionary_locales_;
523 };
524 
525 namespace internal {
526 
527 // Helper function, which if the initial 'span' contains only white-spaces,
528 // moves the selection to a single-codepoint selection on the left side
529 // of this block of white-space.
530 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
531                                             const UnicodeText& context_unicode,
532                                             const UniLib& unilib);
533 
534 // Copies tokens from 'cached_tokens' that are
535 // 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
536 // from the tokens that correspond to 'selection_indices'.
537 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
538                                     CodepointSpan selection_indices,
539                                     TokenSpan tokens_around_selection_to_copy);
540 }  // namespace internal
541 
542 // Interprets the buffer as a Model flatbuffer and returns it for reading.
543 const Model* ViewModel(const void* buffer, int size);
544 
545 // Opens model from given path and runs a function, passing the loaded Model
546 // flatbuffer as an argument.
547 //
548 // This is mainly useful if we don't want to pay the cost for the model
549 // initialization because we'll be only reading some flatbuffer values from the
550 // file.
551 template <typename ReturnType, typename Func>
VisitAnnotatorModel(const std::string & path,Func function)552 ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
553   ScopedMmap mmap(path);
554   if (!mmap.handle().ok()) {
555     function(/*model=*/nullptr);
556   }
557   const Model* model =
558       ViewModel(mmap.handle().start(), mmap.handle().num_bytes());
559   return function(model);
560 }
561 
562 }  // namespace libtextclassifier3
563 
564 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
565