1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator.h"
18
19 #include <algorithm>
20 #include <cctype>
21 #include <cmath>
22 #include <iterator>
23 #include <numeric>
24 #include <unordered_map>
25
26 #include "annotator/collections.h"
27 #include "annotator/model_generated.h"
28 #include "annotator/types.h"
29 #include "utils/base/logging.h"
30 #include "utils/checksum.h"
31 #include "utils/math/softmax.h"
32 #include "utils/regex-match.h"
33 #include "utils/utf8/unicodetext.h"
34 #include "utils/zlib/zlib_regex.h"
35
36
37 namespace libtextclassifier3 {
38
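// Candidate indices kept in a set ordered by a caller-supplied comparator;
// ResolveConflict uses this to keep the per-source chosen indices sorted by
// the start of their spans.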
39 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
40
41 const std::string& Annotator::kPhoneCollection =
42 *[]() { return new std::string("phone"); }();
43 const std::string& Annotator::kAddressCollection =
44 *[]() { return new std::string("address"); }();
45 const std::string& Annotator::kDateCollection =
46 *[]() { return new std::string("date"); }();
47 const std::string& Annotator::kUrlCollection =
48 *[]() { return new std::string("url"); }();
49 const std::string& Annotator::kEmailCollection =
50 *[]() { return new std::string("email"); }();
51
52 namespace {
53 const Model* LoadAndVerifyModel(const void* addr, int size) {
54 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
55 if (VerifyModelBuffer(verifier)) {
56 return GetModel(addr);
57 } else {
58 return nullptr;
59 }
60 }
61
62 // If lib is not nullptr, just returns lib. Otherwise creates a new
63 // instance, stores it in *owned_lib, and returns it.
64 const UniLib* MaybeCreateUnilib(const UniLib* lib,
65 std::unique_ptr<UniLib>* owned_lib) {
66 if (lib) {
67 return lib;
68 } else {
69 owned_lib->reset(new UniLib);
70 return owned_lib->get();
71 }
72 }
73
74 // As above, but for CalendarLib.
75 const CalendarLib* MaybeCreateCalendarlib(
76 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
77 if (lib) {
78 return lib;
79 } else {
80 owned_lib->reset(new CalendarLib);
81 return owned_lib->get();
82 }
83 }
84
85 } // namespace
86
87 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
88 if (!selection_interpreter_) {
89 TC3_CHECK(selection_executor_);
90 selection_interpreter_ = selection_executor_->CreateInterpreter();
91 if (!selection_interpreter_) {
92 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
93 }
94 }
95 return selection_interpreter_.get();
96 }
97
98 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
99 if (!classification_interpreter_) {
100 TC3_CHECK(classification_executor_);
101 classification_interpreter_ = classification_executor_->CreateInterpreter();
102 if (!classification_interpreter_) {
103 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
104 }
105 }
106 return classification_interpreter_.get();
107 }
108
109 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
110 const char* buffer, int size, const UniLib* unilib,
111 const CalendarLib* calendarlib) {
112 const Model* model = LoadAndVerifyModel(buffer, size);
113 if (model == nullptr) {
114 return nullptr;
115 }
116
117 auto classifier =
118 std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
119 if (!classifier->IsInitialized()) {
120 return nullptr;
121 }
122
123 return classifier;
124 }
125
126
127 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
128 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
129 const CalendarLib* calendarlib) {
130 if (!(*mmap)->handle().ok()) {
131 TC3_VLOG(1) << "Mmap failed.";
132 return nullptr;
133 }
134
135 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
136 (*mmap)->handle().num_bytes());
137 if (!model) {
138 TC3_LOG(ERROR) << "Model verification failed.";
139 return nullptr;
140 }
141
142 auto classifier = std::unique_ptr<Annotator>(
143 new Annotator(mmap, model, unilib, calendarlib));
144 if (!classifier->IsInitialized()) {
145 return nullptr;
146 }
147
148 return classifier;
149 }
150
151 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
152 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
153 std::unique_ptr<CalendarLib> calendarlib) {
154 if (!(*mmap)->handle().ok()) {
155 TC3_VLOG(1) << "Mmap failed.";
156 return nullptr;
157 }
158
159 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
160 (*mmap)->handle().num_bytes());
161 if (model == nullptr) {
162 TC3_LOG(ERROR) << "Model verification failed.";
163 return nullptr;
164 }
165
166 auto classifier = std::unique_ptr<Annotator>(
167 new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
168 if (!classifier->IsInitialized()) {
169 return nullptr;
170 }
171
172 return classifier;
173 }
174
175 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
176 int fd, int offset, int size, const UniLib* unilib,
177 const CalendarLib* calendarlib) {
178 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
179 return FromScopedMmap(&mmap, unilib, calendarlib);
180 }
181
182 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
183 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
184 std::unique_ptr<CalendarLib> calendarlib) {
185 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
186 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
187 }
188
189 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
190 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
191 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
192 return FromScopedMmap(&mmap, unilib, calendarlib);
193 }
194
195 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
196 int fd, std::unique_ptr<UniLib> unilib,
197 std::unique_ptr<CalendarLib> calendarlib) {
198 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
199 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
200 }
201
202 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
203 const UniLib* unilib,
204 const CalendarLib* calendarlib) {
205 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
206 return FromScopedMmap(&mmap, unilib, calendarlib);
207 }
208
209 std::unique_ptr<Annotator> Annotator::FromPath(
210 const std::string& path, std::unique_ptr<UniLib> unilib,
211 std::unique_ptr<CalendarLib> calendarlib) {
212 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
213 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
214 }
215
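// A minimal usage sketch (hypothetical model path): the FromPath and
// FromFileDescriptor factories return nullptr when mmapping or model
// verification fails, so callers should null-check the result.
//   std::unique_ptr<Annotator> annotator = Annotator::FromPath(
//       "/path/to/model.fb", /*unilib=*/nullptr, /*calendarlib=*/nullptr);
//   if (annotator == nullptr) { /* handle the load failure */ }
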
216 Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
217 const UniLib* unilib, const CalendarLib* calendarlib)
218 : model_(model),
219 mmap_(std::move(*mmap)),
220 owned_unilib_(nullptr),
221 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
222 owned_calendarlib_(nullptr),
223 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
224 ValidateAndInitialize();
225 }
226
227 Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
228 std::unique_ptr<UniLib> unilib,
229 std::unique_ptr<CalendarLib> calendarlib)
230 : model_(model),
231 mmap_(std::move(*mmap)),
232 owned_unilib_(std::move(unilib)),
233 unilib_(owned_unilib_.get()),
234 owned_calendarlib_(std::move(calendarlib)),
235 calendarlib_(owned_calendarlib_.get()) {
236 ValidateAndInitialize();
237 }
238
239 Annotator::Annotator(const Model* model, const UniLib* unilib,
240 const CalendarLib* calendarlib)
241 : model_(model),
242 owned_unilib_(nullptr),
243 unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
244 owned_calendarlib_(nullptr),
245 calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
246 ValidateAndInitialize();
247 }
248
249 void Annotator::ValidateAndInitialize() {
250 initialized_ = false;
251
252 if (model_ == nullptr) {
253 TC3_LOG(ERROR) << "No model specified.";
254 return;
255 }
256
257 const bool model_enabled_for_annotation =
258 (model_->triggering_options() != nullptr &&
259 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
260 const bool model_enabled_for_classification =
261 (model_->triggering_options() != nullptr &&
262 (model_->triggering_options()->enabled_modes() &
263 ModeFlag_CLASSIFICATION));
264 const bool model_enabled_for_selection =
265 (model_->triggering_options() != nullptr &&
266 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
267
268 // Annotation requires the selection model.
269 if (model_enabled_for_annotation || model_enabled_for_selection) {
270 if (!model_->selection_options()) {
271 TC3_LOG(ERROR) << "No selection options.";
272 return;
273 }
274 if (!model_->selection_feature_options()) {
275 TC3_LOG(ERROR) << "No selection feature options.";
276 return;
277 }
278 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
279 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
280 return;
281 }
282 if (!model_->selection_model()) {
283 TC3_LOG(ERROR) << "No selection model.";
284 return;
285 }
286 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
287 if (!selection_executor_) {
288 TC3_LOG(ERROR) << "Could not initialize selection executor.";
289 return;
290 }
291 selection_feature_processor_.reset(
292 new FeatureProcessor(model_->selection_feature_options(), unilib_));
293 }
294
295 // Annotation requires the classification model for conflict resolution and
296 // scoring.
297 // Selection requires the classification model for conflict resolution.
298 if (model_enabled_for_annotation || model_enabled_for_classification ||
299 model_enabled_for_selection) {
300 if (!model_->classification_options()) {
301 TC3_LOG(ERROR) << "No classification options.";
302 return;
303 }
304
305 if (!model_->classification_feature_options()) {
306 TC3_LOG(ERROR) << "No classification feature options.";
307 return;
308 }
309
310 if (!model_->classification_feature_options()
311 ->bounds_sensitive_features()) {
312 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
313 return;
314 }
315 if (!model_->classification_model()) {
316 TC3_LOG(ERROR) << "No clf model.";
317 return;
318 }
319
320 classification_executor_ =
321 ModelExecutor::FromBuffer(model_->classification_model());
322 if (!classification_executor_) {
323 TC3_LOG(ERROR) << "Could not initialize classification executor.";
324 return;
325 }
326
327 classification_feature_processor_.reset(new FeatureProcessor(
328 model_->classification_feature_options(), unilib_));
329 }
330
331 // The embeddings need to be specified if the model is to be used for
332 // classification or selection.
333 if (model_enabled_for_annotation || model_enabled_for_classification ||
334 model_enabled_for_selection) {
335 if (!model_->embedding_model()) {
336 TC3_LOG(ERROR) << "No embedding model.";
337 return;
338 }
339
340 // Check that the embedding size of the selection and classification model
341 // matches, as they are using the same embeddings.
342 if (model_enabled_for_selection &&
343 (model_->selection_feature_options()->embedding_size() !=
344 model_->classification_feature_options()->embedding_size() ||
345 model_->selection_feature_options()->embedding_quantization_bits() !=
346 model_->classification_feature_options()
347 ->embedding_quantization_bits())) {
348 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
349 return;
350 }
351
352 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
353 model_->embedding_model(),
354 model_->classification_feature_options()->embedding_size(),
355 model_->classification_feature_options()->embedding_quantization_bits(),
356 model_->embedding_pruning_mask());
357 if (!embedding_executor_) {
358 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
359 return;
360 }
361 }
362
363 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
364 if (model_->regex_model()) {
365 if (!InitializeRegexModel(decompressor.get())) {
366 TC3_LOG(ERROR) << "Could not initialize regex model.";
367 return;
368 }
369 }
370
371 if (model_->datetime_model()) {
372 datetime_parser_ = DatetimeParser::Instance(
373 model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
374 if (!datetime_parser_) {
375 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
376 return;
377 }
378 }
379
380 if (model_->output_options()) {
381 if (model_->output_options()->filtered_collections_annotation()) {
382 for (const auto collection :
383 *model_->output_options()->filtered_collections_annotation()) {
384 filtered_collections_annotation_.insert(collection->str());
385 }
386 }
387 if (model_->output_options()->filtered_collections_classification()) {
388 for (const auto collection :
389 *model_->output_options()->filtered_collections_classification()) {
390 filtered_collections_classification_.insert(collection->str());
391 }
392 }
393 if (model_->output_options()->filtered_collections_selection()) {
394 for (const auto collection :
395 *model_->output_options()->filtered_collections_selection()) {
396 filtered_collections_selection_.insert(collection->str());
397 }
398 }
399 }
400
401 if (model_->number_annotator_options() &&
402 model_->number_annotator_options()->enabled()) {
403 if (selection_feature_processor_ == nullptr) {
404 TC3_LOG(ERROR)
405 << "Could not initialize NumberAnnotator without a feature processor";
406 return;
407 }
408
409 number_annotator_.reset(
410 new NumberAnnotator(model_->number_annotator_options(),
411 selection_feature_processor_.get()));
412 }
413
414 if (model_->duration_annotator_options() &&
415 model_->duration_annotator_options()->enabled()) {
416 duration_annotator_.reset(
417 new DurationAnnotator(model_->duration_annotator_options(),
418 selection_feature_processor_.get()));
419 }
420
421 if (model_->entity_data_schema()) {
422 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
423 model_->entity_data_schema()->Data(),
424 model_->entity_data_schema()->size());
425 if (entity_data_schema_ == nullptr) {
426 TC3_LOG(ERROR) << "Could not load entity data schema data.";
427 return;
428 }
429
430 entity_data_builder_.reset(
431 new ReflectiveFlatbufferBuilder(entity_data_schema_));
432 } else {
433 entity_data_schema_ = nullptr;
434 entity_data_builder_ = nullptr;
435 }
436
437 if (model_->triggering_locales() &&
438 !ParseLocales(model_->triggering_locales()->c_str(),
439 &model_triggering_locales_)) {
440 TC3_LOG(ERROR) << "Could not parse model supported locales.";
441 return;
442 }
443
444 if (model_->triggering_options() != nullptr &&
445 model_->triggering_options()->locales() != nullptr &&
446 !ParseLocales(model_->triggering_options()->locales()->c_str(),
447 &ml_model_triggering_locales_)) {
448 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
449 return;
450 }
451
452 if (model_->triggering_options() != nullptr &&
453 model_->triggering_options()->dictionary_locales() != nullptr &&
454 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
455 &dictionary_locales_)) {
456 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
457 return;
458 }
459
460 initialized_ = true;
461 }
462
463 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
464 if (!model_->regex_model()->patterns()) {
465 return true;
466 }
467
468 // Initialize pattern recognizers.
469 int regex_pattern_id = 0;
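// 'regex_pattern_id' is the index of the pattern within 'regex_patterns_';
// the per-mode vectors below store these indices so that annotation,
// classification and selection each iterate only over their enabled patterns.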
470 for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
471 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
472 UncompressMakeRegexPattern(
473 *unilib_, regex_pattern->pattern(),
474 regex_pattern->compressed_pattern(),
475 model_->regex_model()->lazy_regex_compilation(), decompressor);
476 if (!compiled_pattern) {
477 TC3_LOG(INFO) << "Failed to load regex pattern";
478 return false;
479 }
480
481 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
482 annotation_regex_patterns_.push_back(regex_pattern_id);
483 }
484 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
485 classification_regex_patterns_.push_back(regex_pattern_id);
486 }
487 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
488 selection_regex_patterns_.push_back(regex_pattern_id);
489 }
490 regex_patterns_.push_back({
491 regex_pattern,
492 std::move(compiled_pattern),
493 });
494 ++regex_pattern_id;
495 }
496
497 return true;
498 }
499
500 bool Annotator::InitializeKnowledgeEngine(
501 const std::string& serialized_config) {
502 std::unique_ptr<KnowledgeEngine> knowledge_engine(
503 new KnowledgeEngine(unilib_));
504 if (!knowledge_engine->Initialize(serialized_config)) {
505 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
506 return false;
507 }
508 knowledge_engine_ = std::move(knowledge_engine);
509 return true;
510 }
511
512 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
513 std::unique_ptr<ContactEngine> contact_engine(
514 new ContactEngine(selection_feature_processor_.get(), unilib_));
515 if (!contact_engine->Initialize(serialized_config)) {
516 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
517 return false;
518 }
519 contact_engine_ = std::move(contact_engine);
520 return true;
521 }
522
523 bool Annotator::InitializeInstalledAppEngine(
524 const std::string& serialized_config) {
525 std::unique_ptr<InstalledAppEngine> installed_app_engine(
526 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
527 if (!installed_app_engine->Initialize(serialized_config)) {
528 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
529 return false;
530 }
531 installed_app_engine_ = std::move(installed_app_engine);
532 return true;
533 }
534
535 namespace {
536
537 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
538 int count = 0;
539 int i = 0;
540 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
541 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
542 if (i >= selection_indices.first && i < selection_indices.second &&
543 isdigit(*it)) {
544 ++count;
545 }
546 }
547 return count;
548 }
549
550 } // namespace
551
552 namespace internal {
553 // Helper function: if the initial 'span' contains only white-spaces, moves
554 // the selection to a single-codepoint selection on the left side of this
555 // whitespace run; otherwise returns 'span' unchanged.
556 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
557 const UnicodeText& context_unicode,
558 const UniLib& unilib) {
559 TC3_CHECK(ValidNonEmptySpan(span));
560
561 UnicodeText::const_iterator it;
562
563 // Check that the current selection is all whitespaces.
564 it = context_unicode.begin();
565 std::advance(it, span.first);
566 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
567 if (!unilib.IsWhitespace(*it)) {
568 return span;
569 }
570 }
571
572 CodepointSpan result;
573
574 // Try moving left.
575 result = span;
576 it = context_unicode.begin();
577 std::advance(it, span.first);
578 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
579 --result.first;
580 --it;
581 }
582 result.second = result.first + 1;
583 if (!unilib.IsWhitespace(*it)) {
584 return result;
585 }
586
587 // If moving left didn't find a non-whitespace character, just return the
588 // original span.
589 return span;
590 }
591 } // namespace internal
592
593 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
594 return !span.classification.empty() &&
595 filtered_collections_annotation_.find(
596 span.classification[0].collection) !=
597 filtered_collections_annotation_.end();
598 }
599
600 bool Annotator::FilteredForClassification(
601 const ClassificationResult& classification) const {
602 return filtered_collections_classification_.find(classification.collection) !=
603 filtered_collections_classification_.end();
604 }
605
606 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
607 return !span.classification.empty() &&
608 filtered_collections_selection_.find(
609 span.classification[0].collection) !=
610 filtered_collections_selection_.end();
611 }
612
613 namespace {
614 inline bool ClassifiedAsOther(
615 const std::vector<ClassificationResult>& classification) {
616 return !classification.empty() &&
617 classification[0].collection == Collections::Other();
618 }
619
620 float GetPriorityScore(
621 const std::vector<ClassificationResult>& classification) {
622 if (!classification.empty() && !ClassifiedAsOther(classification)) {
623 return classification[0].priority_score;
624 } else {
625 return -1.0;
626 }
627 }
628 } // namespace
629
630 bool Annotator::VerifyRegexMatchCandidate(
631 const std::string& context, const VerificationOptions* verification_options,
632 const std::string& match, const UniLib::RegexMatcher* matcher) const {
633 if (verification_options == nullptr) {
634 return true;
635 }
636 if (verification_options->verify_luhn_checksum() &&
637 !VerifyLuhnChecksum(match)) {
638 return false;
639 }
640 const int lua_verifier = verification_options->lua_verifier();
641 if (lua_verifier >= 0) {
642 if (model_->regex_model()->lua_verifier() == nullptr ||
643 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
644 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
645 return false;
646 }
647 return VerifyMatch(
648 context, matcher,
649 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
650 }
651 return true;
652 }
653
654 CodepointSpan Annotator::SuggestSelection(
655 const std::string& context, CodepointSpan click_indices,
656 const SelectionOptions& options) const {
657 CodepointSpan original_click_indices = click_indices;
658 if (!initialized_) {
659 TC3_LOG(ERROR) << "Not initialized";
660 return original_click_indices;
661 }
662 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
663 return original_click_indices;
664 }
665
666 std::vector<Locale> detected_text_language_tags;
667 if (!ParseLocales(options.detected_text_language_tags,
668 &detected_text_language_tags)) {
669 TC3_LOG(WARNING)
670 << "Failed to parse the detected_text_language_tags in options: "
671 << options.detected_text_language_tags;
672 }
673 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
674 model_triggering_locales_,
675 /*default_value=*/true)) {
676 return original_click_indices;
677 }
678
679 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
680 /*do_copy=*/false);
681
682 if (!context_unicode.is_valid()) {
683 return original_click_indices;
684 }
685
686 const int context_codepoint_size = context_unicode.size_codepoints();
687
688 if (click_indices.first < 0 || click_indices.second < 0 ||
689 click_indices.first >= context_codepoint_size ||
690 click_indices.second > context_codepoint_size ||
691 click_indices.first >= click_indices.second) {
692 TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
693 << click_indices.first << " " << click_indices.second;
694 return original_click_indices;
695 }
696
697 if (model_->snap_whitespace_selections()) {
698 // We want to expand a purely white-space selection to a multi-selection it
699 // would've been part of. But with this feature disabled we would do a no-
700 // op, because no token is found. Therefore, we need to modify the
701 // 'click_indices' a bit to include a part of the token, so that the click-
702 // finding logic finds the clicked token correctly. This modification is
703 // done by the following function. Note, that it's enough to check the left
704 // side of the current selection, because if the white-space is a part of a
705 // multi-selection, necessarily both tokens - on the left and the right
706 // sides need to be selected. Thus snapping only to the left is sufficient
707 // (there's a check at the bottom that makes sure that if we snap to the
708 // left token but the result does not contain the initial white-space,
709 // returns the original indices).
710 click_indices = internal::SnapLeftIfWhitespaceSelection(
711 click_indices, context_unicode, *unilib_);
712 }
713
714 std::vector<AnnotatedSpan> candidates;
715 InterpreterManager interpreter_manager(selection_executor_.get(),
716 classification_executor_.get());
717 std::vector<Token> tokens;
718 if (!ModelSuggestSelection(context_unicode, click_indices,
719 detected_text_language_tags, &interpreter_manager,
720 &tokens, &candidates)) {
721 TC3_LOG(ERROR) << "Model suggest selection failed.";
722 return original_click_indices;
723 }
724 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
725 /*is_serialized_entity_data_enabled=*/false)) {
726 TC3_LOG(ERROR) << "Regex suggest selection failed.";
727 return original_click_indices;
728 }
729 if (!DatetimeChunk(
730 UTF8ToUnicodeText(context, /*do_copy=*/false),
731 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
732 options.locales, ModeFlag_SELECTION, options.annotation_usecase,
733 /*is_serialized_entity_data_enabled=*/false, &candidates)) {
734 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
735 return original_click_indices;
736 }
737 if (knowledge_engine_ != nullptr &&
738 !knowledge_engine_->Chunk(context, &candidates)) {
739 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
740 return original_click_indices;
741 }
742 if (contact_engine_ != nullptr &&
743 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
744 TC3_LOG(ERROR) << "Contact suggest selection failed.";
745 return original_click_indices;
746 }
747 if (installed_app_engine_ != nullptr &&
748 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
749 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
750 return original_click_indices;
751 }
752 if (number_annotator_ != nullptr &&
753 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
754 &candidates)) {
755 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
756 return original_click_indices;
757 }
758 if (duration_annotator_ != nullptr &&
759 !duration_annotator_->FindAll(context_unicode, tokens,
760 options.annotation_usecase, &candidates)) {
761 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
762 return original_click_indices;
763 }
764
765 // Sort candidates according to their position in the input, so that the next
766 // code can assume that any connected component of overlapping spans forms a
767 // contiguous block.
768 std::sort(candidates.begin(), candidates.end(),
769 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
770 return a.span.first < b.span.first;
771 });
772
773 std::vector<int> candidate_indices;
774 if (!ResolveConflicts(candidates, context, tokens,
775 detected_text_language_tags, options.annotation_usecase,
776 &interpreter_manager, &candidate_indices)) {
777 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
778 return original_click_indices;
779 }
780
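// Sort the conflict-resolved candidates by descending priority score so that
// the highest-priority span overlapping the click is returned below.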
781 std::sort(candidate_indices.begin(), candidate_indices.end(),
782 [&candidates](int a, int b) {
783 return GetPriorityScore(candidates[a].classification) >
784 GetPriorityScore(candidates[b].classification);
785 });
786
787 for (const int i : candidate_indices) {
788 if (SpansOverlap(candidates[i].span, click_indices) &&
789 SpansOverlap(candidates[i].span, original_click_indices)) {
790 // Run model classification if not present but requested and there's a
791 // classification collection filter specified.
792 if (candidates[i].classification.empty() &&
793 model_->selection_options()->always_classify_suggested_selection() &&
794 !filtered_collections_selection_.empty()) {
795 if (!ModelClassifyText(context, detected_text_language_tags,
796 candidates[i].span, &interpreter_manager,
797 /*embedding_cache=*/nullptr,
798 &candidates[i].classification)) {
799 return original_click_indices;
800 }
801 }
802
803 // Ignore if span classification is filtered.
804 if (FilteredForSelection(candidates[i])) {
805 return original_click_indices;
806 }
807
808 return candidates[i].span;
809 }
810 }
811
812 return original_click_indices;
813 }
814
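// A minimal usage sketch for SuggestSelection (hypothetical text and indices):
//   SelectionOptions options;
//   // Click on the codepoint span [5, 6) of the text.
//   const CodepointSpan selection =
//       annotator->SuggestSelection("call 555-1234", {5, 6}, options);
// On any internal failure the original click indices are returned unchanged.
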
815 namespace {
816 // Helper function that returns the index of the first candidate that
817 // transitively does not overlap with the candidate at 'start_index'. If the
818 // end of 'candidates' is reached, it returns the index one past the end of
819 // the array.
820 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
821 int start_index) {
822 int first_non_overlapping = start_index + 1;
823 CodepointSpan conflicting_span = candidates[start_index].span;
824 while (
825 first_non_overlapping < candidates.size() &&
826 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
827 // Grow the span to include the current one.
828 conflicting_span.second = std::max(
829 conflicting_span.second, candidates[first_non_overlapping].span.second);
830
831 ++first_non_overlapping;
832 }
833 return first_non_overlapping;
834 }
835 } // namespace
836
837 bool Annotator::ResolveConflicts(
838 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
839 const std::vector<Token>& cached_tokens,
840 const std::vector<Locale>& detected_text_language_tags,
841 AnnotationUsecase annotation_usecase,
842 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
843 result->clear();
844 result->reserve(candidates.size());
845 for (int i = 0; i < candidates.size();) {
846 int first_non_overlapping =
847 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
848
849 const bool conflict_found = first_non_overlapping != (i + 1);
850 if (conflict_found) {
851 std::vector<int> candidate_indices;
852 if (!ResolveConflict(context, cached_tokens, candidates,
853 detected_text_language_tags, i,
854 first_non_overlapping, annotation_usecase,
855 interpreter_manager, &candidate_indices)) {
856 return false;
857 }
858 result->insert(result->end(), candidate_indices.begin(),
859 candidate_indices.end());
860 } else {
861 result->push_back(i);
862 }
863
864 // Skip over the whole conflicting group/go to next candidate.
865 i = first_non_overlapping;
866 }
867 return true;
868 }
869
870 namespace {
871 // Returns true if the two given sources conflict in the given annotation
872 // usecase.
873 // - In the SMART usecase, all sources conflict, because there is only one
874 // possible annotation for a given span.
875 // - In the RAW usecase, certain annotations are allowed to overlap (e.g.
876 // datetime and duration), while others are not (e.g. duration and number).
877 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
878 const AnnotatedSpan::Source source1,
879 const AnnotatedSpan::Source source2) {
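// 'source_mask' has one bit set for each of the two sources, so the checks
// below are symmetric in 'source1' and 'source2'.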
880 uint32 source_mask =
881 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
882
883 switch (annotation_usecase) {
884 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
885 // In the SMART mode, all annotations conflict.
886 return true;
887
888 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
889 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
890 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
891 // hours" (duration).
892 if ((source_mask &
893 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
894 (source_mask &
895 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
896 return false;
897 }
898
899 // A KNOWLEDGE entity does not conflict with anything.
900 if ((source_mask &
901 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
902 return false;
903 }
904
905 // Entities from other sources can conflict.
906 return true;
907 }
908 }
909 } // namespace
910
911 bool Annotator::ResolveConflict(
912 const std::string& context, const std::vector<Token>& cached_tokens,
913 const std::vector<AnnotatedSpan>& candidates,
914 const std::vector<Locale>& detected_text_language_tags, int start_index,
915 int end_index, AnnotationUsecase annotation_usecase,
916 InterpreterManager* interpreter_manager,
917 std::vector<int>* chosen_indices) const {
918 std::vector<int> conflicting_indices;
919 std::unordered_map<int, float> scores;
920 for (int i = start_index; i < end_index; ++i) {
921 conflicting_indices.push_back(i);
922 if (!candidates[i].classification.empty()) {
923 scores[i] = GetPriorityScore(candidates[i].classification);
924 continue;
925 }
926
927 // OPTIMIZATION: So that we don't have to classify all the ML model
928 // spans a priori, we wait until we get here, when they conflict with
929 // something and we need the actual classification scores. So if the
930 // candidate conflicts and comes from the model, we need to run a
931 // classification to determine its priority:
932 std::vector<ClassificationResult> classification;
933 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
934 candidates[i].span, interpreter_manager,
935 /*embedding_cache=*/nullptr, &classification)) {
936 return false;
937 }
938
939 if (!classification.empty()) {
940 scores[i] = GetPriorityScore(classification);
941 }
942 }
943
944 std::sort(conflicting_indices.begin(), conflicting_indices.end(),
945 [&scores](int i, int j) { return scores[i] > scores[j]; });
946
947 // Here we keep a set of indices that were chosen, per-source, to enable
948 // effective computation.
949 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
950 chosen_indices_for_source_map;
951
952 // Greedily place the candidates if they don't conflict with the already
953 // placed ones.
954 for (int i = 0; i < conflicting_indices.size(); ++i) {
955 const int considered_candidate = conflicting_indices[i];
956
957 // See if there is a conflict between the candidate and all already placed
958 // candidates.
959 bool conflict = false;
960 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
961 for (auto& source_set_pair : chosen_indices_for_source_map) {
962 if (source_set_pair.first == candidates[considered_candidate].source) {
963 chosen_indices_for_source_ptr = &source_set_pair.second;
964 }
965
966 if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
967 candidates[considered_candidate].source) &&
968 DoesCandidateConflict(considered_candidate, candidates,
969 source_set_pair.second)) {
970 conflict = true;
971 break;
972 }
973 }
974
975 // Skip the candidate if a conflict was found.
976 if (conflict) {
977 continue;
978 }
979
980 // If the set of indices for the current source doesn't exist yet,
981 // initialize it.
982 if (chosen_indices_for_source_ptr == nullptr) {
983 SortedIntSet new_set([&candidates](int a, int b) {
984 return candidates[a].span.first < candidates[b].span.first;
985 });
986 chosen_indices_for_source_map[candidates[considered_candidate].source] =
987 std::move(new_set);
988 chosen_indices_for_source_ptr =
989 &chosen_indices_for_source_map[candidates[considered_candidate]
990 .source];
991 }
992
993 // Place the candidate to the output and to the per-source conflict set.
994 chosen_indices->push_back(considered_candidate);
995 chosen_indices_for_source_ptr->insert(considered_candidate);
996 }
997
998 std::sort(chosen_indices->begin(), chosen_indices->end());
999
1000 return true;
1001 }
1002
1003 bool Annotator::ModelSuggestSelection(
1004 const UnicodeText& context_unicode, CodepointSpan click_indices,
1005 const std::vector<Locale>& detected_text_language_tags,
1006 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1007 std::vector<AnnotatedSpan>* result) const {
1008 if (model_->triggering_options() == nullptr ||
1009 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1010 return true;
1011 }
1012
1013 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1014 ml_model_triggering_locales_,
1015 /*default_value=*/true)) {
1016 return true;
1017 }
1018
1019 int click_pos;
1020 *tokens = selection_feature_processor_->Tokenize(context_unicode);
1021 selection_feature_processor_->RetokenizeAndFindClick(
1022 context_unicode, click_indices,
1023 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1024 tokens, &click_pos);
1025 if (click_pos == kInvalidIndex) {
1026 TC3_VLOG(1) << "Could not calculate the click position.";
1027 return false;
1028 }
1029
1030 const int symmetry_context_size =
1031 model_->selection_options()->symmetry_context_size();
1032 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1033 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1034 ->bounds_sensitive_features();
1035
1036 // The symmetry context span is the clicked token with symmetry_context_size
1037 // tokens on either side.
1038 const TokenSpan symmetry_context_span = IntersectTokenSpans(
1039 ExpandTokenSpan(SingleTokenSpan(click_pos),
1040 /*num_tokens_left=*/symmetry_context_size,
1041 /*num_tokens_right=*/symmetry_context_size),
1042 {0, tokens->size()});
1043
1044 // Compute the extraction span based on the model type.
1045 TokenSpan extraction_span;
1046 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1047 // The extraction span is the symmetry context span expanded to include
1048 // max_selection_span tokens on either side, which is how far a selection
1049 // can stretch from the click, plus a relevant number of tokens outside of
1050 // the bounds of the selection.
1051 const int max_selection_span =
1052 selection_feature_processor_->GetOptions()->max_selection_span();
1053 extraction_span =
1054 ExpandTokenSpan(symmetry_context_span,
1055 /*num_tokens_left=*/max_selection_span +
1056 bounds_sensitive_features->num_tokens_before(),
1057 /*num_tokens_right=*/max_selection_span +
1058 bounds_sensitive_features->num_tokens_after());
1059 } else {
1060 // The extraction span is the symmetry context span expanded to include
1061 // context_size tokens on either side.
1062 const int context_size =
1063 selection_feature_processor_->GetOptions()->context_size();
1064 extraction_span = ExpandTokenSpan(symmetry_context_span,
1065 /*num_tokens_left=*/context_size,
1066 /*num_tokens_right=*/context_size);
1067 }
1068 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
1069
1070 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1071 *tokens, extraction_span)) {
1072 return true;
1073 }
1074
1075 std::unique_ptr<CachedFeatures> cached_features;
1076 if (!selection_feature_processor_->ExtractFeatures(
1077 *tokens, extraction_span,
1078 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1079 embedding_executor_.get(),
1080 /*embedding_cache=*/nullptr,
1081 selection_feature_processor_->EmbeddingSize() +
1082 selection_feature_processor_->DenseFeaturesCount(),
1083 &cached_features)) {
1084 TC3_LOG(ERROR) << "Could not extract features.";
1085 return false;
1086 }
1087
1088 // Produce selection model candidates.
1089 std::vector<TokenSpan> chunks;
1090 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1091 interpreter_manager->SelectionInterpreter(), *cached_features,
1092 &chunks)) {
1093 TC3_LOG(ERROR) << "Could not chunk.";
1094 return false;
1095 }
1096
1097 for (const TokenSpan& chunk : chunks) {
1098 AnnotatedSpan candidate;
1099 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1100 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1101 if (model_->selection_options()->strip_unpaired_brackets()) {
1102 candidate.span =
1103 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1104 }
1105
1106 // Only output non-empty spans.
1107 if (candidate.span.first != candidate.span.second) {
1108 result->push_back(candidate);
1109 }
1110 }
1111 return true;
1112 }
1113
1114 bool Annotator::ModelClassifyText(
1115 const std::string& context,
1116 const std::vector<Locale>& detected_text_language_tags,
1117 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1118 FeatureProcessor::EmbeddingCache* embedding_cache,
1119 std::vector<ClassificationResult>* classification_results) const {
1120 return ModelClassifyText(context, {}, detected_text_language_tags,
1121 selection_indices, interpreter_manager,
1122 embedding_cache, classification_results);
1123 }
1124
1125 namespace internal {
1126 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1127 CodepointSpan selection_indices,
1128 TokenSpan tokens_around_selection_to_copy) {
1129 const auto first_selection_token = std::upper_bound(
1130 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1131 [](int selection_start, const Token& token) {
1132 return selection_start < token.end;
1133 });
1134 const auto last_selection_token = std::lower_bound(
1135 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1136 [](const Token& token, int selection_end) {
1137 return token.start < selection_end;
1138 });
1139
1140 const int64 first_token = std::max(
1141 static_cast<int64>(0),
1142 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1143 tokens_around_selection_to_copy.first));
1144 const int64 last_token = std::min(
1145 static_cast<int64>(cached_tokens.size()),
1146 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1147 tokens_around_selection_to_copy.second));
1148
1149 std::vector<Token> tokens;
1150 tokens.reserve(last_token - first_token);
1151 for (int i = first_token; i < last_token; ++i) {
1152 tokens.push_back(cached_tokens[i]);
1153 }
1154 return tokens;
1155 }
1156 } // namespace internal
1157
1158 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1159 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1160 bounds_sensitive_features =
1161 classification_feature_processor_->GetOptions()
1162 ->bounds_sensitive_features();
1163 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1164 // The extraction span is the selection span expanded to include a relevant
1165 // number of tokens outside of the bounds of the selection.
1166 return {bounds_sensitive_features->num_tokens_before(),
1167 bounds_sensitive_features->num_tokens_after()};
1168 } else {
1169 // The extraction span is the clicked token with context_size tokens on
1170 // either side.
1171 const int context_size =
1172 selection_feature_processor_->GetOptions()->context_size();
1173 return {context_size, context_size};
1174 }
1175 }
1176
1177 namespace {
1178 // Sorts the classification results from high score to low score.
1179 void SortClassificationResults(
1180 std::vector<ClassificationResult>* classification_results) {
1181 std::sort(classification_results->begin(), classification_results->end(),
1182 [](const ClassificationResult& a, const ClassificationResult& b) {
1183 return a.score > b.score;
1184 });
1185 }
1186 } // namespace
1187
1188 bool Annotator::ModelClassifyText(
1189 const std::string& context, const std::vector<Token>& cached_tokens,
1190 const std::vector<Locale>& detected_text_language_tags,
1191 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1192 FeatureProcessor::EmbeddingCache* embedding_cache,
1193 std::vector<ClassificationResult>* classification_results) const {
1194 std::vector<Token> tokens;
1195 return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1196 selection_indices, interpreter_manager,
1197 embedding_cache, classification_results, &tokens);
1198 }
1199
1200 bool Annotator::ModelClassifyText(
1201 const std::string& context, const std::vector<Token>& cached_tokens,
1202 const std::vector<Locale>& detected_text_language_tags,
1203 CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1204 FeatureProcessor::EmbeddingCache* embedding_cache,
1205 std::vector<ClassificationResult>* classification_results,
1206 std::vector<Token>* tokens) const {
1207 if (model_->triggering_options() == nullptr ||
1208 !(model_->triggering_options()->enabled_modes() &
1209 ModeFlag_CLASSIFICATION)) {
1210 return true;
1211 }
1212
1213 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1214 ml_model_triggering_locales_,
1215 /*default_value=*/true)) {
1216 return true;
1217 }
1218
1219 if (cached_tokens.empty()) {
1220 *tokens = classification_feature_processor_->Tokenize(context);
1221 } else {
1222 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1223 ClassifyTextUpperBoundNeededTokens());
1224 }
1225
1226 int click_pos;
1227 classification_feature_processor_->RetokenizeAndFindClick(
1228 context, selection_indices,
1229 classification_feature_processor_->GetOptions()
1230 ->only_use_line_with_click(),
1231 tokens, &click_pos);
1232 const TokenSpan selection_token_span =
1233 CodepointSpanToTokenSpan(*tokens, selection_indices);
1234 const int selection_num_tokens = TokenSpanSize(selection_token_span);
1235 if (model_->classification_options()->max_num_tokens() > 0 &&
1236 model_->classification_options()->max_num_tokens() <
1237 selection_num_tokens) {
1238 *classification_results = {{Collections::Other(), 1.0}};
1239 return true;
1240 }
1241
1242 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1243 bounds_sensitive_features =
1244 classification_feature_processor_->GetOptions()
1245 ->bounds_sensitive_features();
1246 if (selection_token_span.first == kInvalidIndex ||
1247 selection_token_span.second == kInvalidIndex) {
1248 TC3_LOG(ERROR) << "Could not determine span.";
1249 return false;
1250 }
1251
1252 // Compute the extraction span based on the model type.
1253 TokenSpan extraction_span;
1254 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1255 // The extraction span is the selection span expanded to include a relevant
1256 // number of tokens outside of the bounds of the selection.
1257 extraction_span = ExpandTokenSpan(
1258 selection_token_span,
1259 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1260 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1261 } else {
1262 if (click_pos == kInvalidIndex) {
1263 TC3_LOG(ERROR) << "Couldn't choose a click position.";
1264 return false;
1265 }
1266 // The extraction span is the clicked token with context_size tokens on
1267 // either side.
1268 const int context_size =
1269 classification_feature_processor_->GetOptions()->context_size();
1270 extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1271 /*num_tokens_left=*/context_size,
1272 /*num_tokens_right=*/context_size);
1273 }
1274 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
1275
1276 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1277 *tokens, extraction_span)) {
1278 *classification_results = {{Collections::Other(), 1.0}};
1279 return true;
1280 }
1281
1282 std::unique_ptr<CachedFeatures> cached_features;
1283 if (!classification_feature_processor_->ExtractFeatures(
1284 *tokens, extraction_span, selection_indices,
1285 embedding_executor_.get(), embedding_cache,
1286 classification_feature_processor_->EmbeddingSize() +
1287 classification_feature_processor_->DenseFeaturesCount(),
1288 &cached_features)) {
1289 TC3_LOG(ERROR) << "Could not extract features.";
1290 return false;
1291 }
1292
1293 std::vector<float> features;
1294 features.reserve(cached_features->OutputFeaturesSize());
1295 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1296 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1297 &features);
1298 } else {
1299 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1300 }
1301
1302 TensorView<float> logits = classification_executor_->ComputeLogits(
1303 TensorView<float>(features.data(),
1304 {1, static_cast<int>(features.size())}),
1305 interpreter_manager->ClassificationInterpreter());
1306 if (!logits.is_valid()) {
1307 TC3_LOG(ERROR) << "Couldn't compute logits.";
1308 return false;
1309 }
1310
1311 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1312 logits.dim(1) != classification_feature_processor_->NumCollections()) {
1313 TC3_LOG(ERROR) << "Mismatching output";
1314 return false;
1315 }
1316
1317 const std::vector<float> scores =
1318 ComputeSoftmax(logits.data(), logits.dim(1));
1319
1320 if (scores.empty()) {
1321 *classification_results = {{Collections::Other(), 1.0}};
1322 return true;
1323 }
1324
1325 const int best_score_index =
1326 std::max_element(scores.begin(), scores.end()) - scores.begin();
1327 const std::string top_collection =
1328 classification_feature_processor_->LabelToCollection(best_score_index);
1329
1330 // Sanity checks.
1331 if (top_collection == Collections::Phone()) {
1332 const int digit_count = CountDigits(context, selection_indices);
1333 if (digit_count <
1334 model_->classification_options()->phone_min_num_digits() ||
1335 digit_count >
1336 model_->classification_options()->phone_max_num_digits()) {
1337 *classification_results = {{Collections::Other(), 1.0}};
1338 return true;
1339 }
1340 } else if (top_collection == Collections::Address()) {
1341 if (selection_num_tokens <
1342 model_->classification_options()->address_min_num_tokens()) {
1343 *classification_results = {{Collections::Other(), 1.0}};
1344 return true;
1345 }
1346 } else if (top_collection == Collections::Dictionary()) {
1347 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1348 dictionary_locales_,
1349 /*default_value=*/false)) {
1350 *classification_results = {{Collections::Other(), 1.0}};
1351 return true;
1352 }
1353 }
1354
1355 *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
1356 return true;
1357 }
1358
1359 bool Annotator::RegexClassifyText(
1360 const std::string& context, CodepointSpan selection_indices,
1361 std::vector<ClassificationResult>* classification_result) const {
1362 const std::string selection_text =
1363 UTF8ToUnicodeText(context, /*do_copy=*/false)
1364 .UTF8Substring(selection_indices.first, selection_indices.second);
1365 const UnicodeText selection_text_unicode(
1366 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1367
1368 // Check whether any of the regular expressions match.
1369 for (const int pattern_id : classification_regex_patterns_) {
1370 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1371 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1372 regex_pattern.pattern->Matcher(selection_text_unicode);
1373 int status = UniLib::RegexMatcher::kNoError;
1374 bool matches;
1375 if (regex_pattern.config->use_approximate_matching()) {
1376 matches = matcher->ApproximatelyMatches(&status);
1377 } else {
1378 matches = matcher->Matches(&status);
1379 }
1380 if (status != UniLib::RegexMatcher::kNoError) {
1381 return false;
1382 }
1383 if (matches && VerifyRegexMatchCandidate(
1384 context, regex_pattern.config->verification_options(),
1385 selection_text, matcher.get())) {
1386 classification_result->push_back(
1387 {regex_pattern.config->collection_name()->str(),
1388 regex_pattern.config->target_classification_score(),
1389 regex_pattern.config->priority_score()});
1390 if (!SerializedEntityDataFromRegexMatch(
1391 regex_pattern.config, matcher.get(),
1392 &classification_result->back().serialized_entity_data)) {
1393 TC3_LOG(ERROR) << "Could not get entity data.";
1394 return false;
1395 }
1396 }
1397 }
1398
1399 return true;
1400 }
1401
1402 namespace {
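// Maps the granularity of a datetime parse to the collection reported to the
// caller: sub-day granularities become "datetime", coarser ones "date".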
1403 std::string PickCollectionForDatetime(
1404 const DatetimeParseResult& datetime_parse_result) {
1405 switch (datetime_parse_result.granularity) {
1406 case GRANULARITY_HOUR:
1407 case GRANULARITY_MINUTE:
1408 case GRANULARITY_SECOND:
1409 return Collections::DateTime();
1410 default:
1411 return Collections::Date();
1412 }
1413 }
1414
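// Packs the parsed time (milliseconds since epoch, UTC) and granularity into a
// serialized EntityData flatbuffer that gets attached to the classification
// result.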
1415 std::string CreateDatetimeSerializedEntityData(
1416 const DatetimeParseResult& parse_result) {
1417 EntityDataT entity_data;
1418 entity_data.datetime.reset(new EntityData_::DatetimeT());
1419 entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1420 entity_data.datetime->granularity =
1421 static_cast<EntityData_::Datetime_::Granularity>(
1422 parse_result.granularity);
1423
1424 flatbuffers::FlatBufferBuilder builder;
1425 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1426 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1427 builder.GetSize());
1428 }
1429 } // namespace
1430
1431 bool Annotator::DatetimeClassifyText(
1432 const std::string& context, CodepointSpan selection_indices,
1433 const ClassificationOptions& options,
1434 std::vector<ClassificationResult>* classification_results) const {
1435 if (!datetime_parser_) {
1436 return false;
1437 }
1438
1439 const std::string selection_text =
1440 UTF8ToUnicodeText(context, /*do_copy=*/false)
1441 .UTF8Substring(selection_indices.first, selection_indices.second);
1442
1443 std::vector<DatetimeParseResultSpan> datetime_spans;
1444 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1445 options.reference_timezone, options.locales,
1446 ModeFlag_CLASSIFICATION,
1447 options.annotation_usecase,
1448 /*anchor_start_end=*/true, &datetime_spans)) {
1449     TC3_LOG(ERROR) << "Error during datetime parsing.";
1450 return false;
1451 }
1452 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1453 // Only consider the result valid if the selection and extracted datetime
1454 // spans exactly match.
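    // E.g. for the selection "tomorrow at 5pm" the parser has to consume the
    // whole selection; a parse covering only "5pm" is ignored.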
1455 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1456 datetime_span.span.second + selection_indices.first) ==
1457 selection_indices) {
1458 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1459 classification_results->emplace_back(
1460 PickCollectionForDatetime(parse_result),
1461 datetime_span.target_classification_score);
1462 classification_results->back().datetime_parse_result = parse_result;
1463 classification_results->back().serialized_entity_data =
1464 CreateDatetimeSerializedEntityData(parse_result);
1465 classification_results->back().priority_score =
1466 datetime_span.priority_score;
1467 }
1468 return true;
1469 }
1470 }
1471 return true;
1472 }
1473
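// A minimal caller-side sketch (hypothetical code, not part of this file;
// "annotator" is assumed to be an already initialized Annotator instance and
// "now_ms" a caller-provided timestamp):
//
//   ClassificationOptions options;
//   options.reference_time_ms_utc = now_ms;
//   const std::vector<ClassificationResult> results =
//       annotator->ClassifyText("Call 555-1234 now",
//                               /*selection_indices=*/{5, 13}, options);
//
// selection_indices are codepoint offsets into the UTF-8 context; the
// best-scoring collection comes first in the returned vector.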
1474 std::vector<ClassificationResult> Annotator::ClassifyText(
1475 const std::string& context, CodepointSpan selection_indices,
1476 const ClassificationOptions& options) const {
1477 if (!initialized_) {
1478 TC3_LOG(ERROR) << "Not initialized";
1479 return {};
1480 }
1481
1482 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1483 return {};
1484 }
1485
1486 std::vector<Locale> detected_text_language_tags;
1487 if (!ParseLocales(options.detected_text_language_tags,
1488 &detected_text_language_tags)) {
1489 TC3_LOG(WARNING)
1490 << "Failed to parse the detected_text_language_tags in options: "
1491 << options.detected_text_language_tags;
1492 }
1493 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1494 model_triggering_locales_,
1495 /*default_value=*/true)) {
1496 return {};
1497 }
1498
1499 if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1500 return {};
1501 }
1502
1503 if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
1504 TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
1505 << std::get<0>(selection_indices) << " "
1506 << std::get<1>(selection_indices);
1507 return {};
1508 }
1509
1510 // We'll accumulate a list of candidates, and pick the best candidate in the
1511 // end.
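  // Candidates come from the knowledge, contact and installed-app engines, the
  // regex and datetime models, the number and duration annotators, and the ML
  // model; conflict resolution below picks among them.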
1512 std::vector<AnnotatedSpan> candidates;
1513
1514 // Try the knowledge engine.
1515 // TODO(b/126579108): Propagate error status.
1516 ClassificationResult knowledge_result;
1517 if (knowledge_engine_ && knowledge_engine_->ClassifyText(
1518 context, selection_indices, &knowledge_result)) {
1519 candidates.push_back({selection_indices, {knowledge_result}});
1520 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1521 }
1522
1523 // Try the contact engine.
1524 // TODO(b/126579108): Propagate error status.
1525 ClassificationResult contact_result;
1526 if (contact_engine_ && contact_engine_->ClassifyText(
1527 context, selection_indices, &contact_result)) {
1528 candidates.push_back({selection_indices, {contact_result}});
1529 }
1530
1531 // Try the installed app engine.
1532 // TODO(b/126579108): Propagate error status.
1533 ClassificationResult installed_app_result;
1534 if (installed_app_engine_ &&
1535 installed_app_engine_->ClassifyText(context, selection_indices,
1536 &installed_app_result)) {
1537 candidates.push_back({selection_indices, {installed_app_result}});
1538 }
1539
1540 // Try the regular expression models.
1541 std::vector<ClassificationResult> regex_results;
1542   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1543 return {};
1544 }
1545 for (const ClassificationResult& result : regex_results) {
1546 candidates.push_back({selection_indices, {result}});
1547 }
1548
1549 // Try the date model.
1550 //
1551   // DatetimeClassifyText only returns the first result, which may however
1552   // contain multiple interpretations. These are inserted into the candidates
1553   // as a single AnnotatedSpan, so that the conflict resolution algorithm
1554   // treats them together.
1555 std::vector<ClassificationResult> datetime_results;
1556 if (!DatetimeClassifyText(context, selection_indices, options,
1557 &datetime_results)) {
1558 return {};
1559 }
1560 if (!datetime_results.empty()) {
1561 candidates.push_back({selection_indices, std::move(datetime_results)});
1562 candidates.back().source = AnnotatedSpan::Source::DATETIME;
1563 }
1564
1565 // Try the number annotator.
1566 // TODO(b/126579108): Propagate error status.
1567 ClassificationResult number_annotator_result;
1568 if (number_annotator_ &&
1569 number_annotator_->ClassifyText(
1570 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1571 options.annotation_usecase, &number_annotator_result)) {
1572 candidates.push_back({selection_indices, {number_annotator_result}});
1573 }
1574
1575 // Try the duration annotator.
1576 ClassificationResult duration_annotator_result;
1577 if (duration_annotator_ &&
1578 duration_annotator_->ClassifyText(
1579 UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1580 options.annotation_usecase, &duration_annotator_result)) {
1581 candidates.push_back({selection_indices, {duration_annotator_result}});
1582 candidates.back().source = AnnotatedSpan::Source::DURATION;
1583 }
1584
1585 // Try the ML model.
1586 //
1587   // The output of the model is treated as an exclusive 1-of-N choice, so it
1588   // is inserted into candidates as a single AnnotatedSpan, as opposed to one
1589   // span per result as e.g. the regex model produces.
1590 InterpreterManager interpreter_manager(selection_executor_.get(),
1591 classification_executor_.get());
1592 std::vector<ClassificationResult> model_results;
1593 std::vector<Token> tokens;
1594 if (!ModelClassifyText(
1595 context, /*cached_tokens=*/{}, detected_text_language_tags,
1596 selection_indices, &interpreter_manager,
1597 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1598 return {};
1599 }
1600 if (!model_results.empty()) {
1601 candidates.push_back({selection_indices, std::move(model_results)});
1602 }
1603
1604 std::vector<int> candidate_indices;
1605 if (!ResolveConflicts(candidates, context, tokens,
1606 detected_text_language_tags, options.annotation_usecase,
1607 &interpreter_manager, &candidate_indices)) {
1608 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1609 return {};
1610 }
1611
1612 std::vector<ClassificationResult> results;
1613 for (const int i : candidate_indices) {
1614 for (const ClassificationResult& result : candidates[i].classification) {
1615 if (!FilteredForClassification(result)) {
1616 results.push_back(result);
1617 }
1618 }
1619 }
1620
1621 // Sort results according to score.
1622 std::sort(results.begin(), results.end(),
1623 [](const ClassificationResult& a, const ClassificationResult& b) {
1624 return a.score > b.score;
1625 });
1626
1627 if (results.empty()) {
1628 results = {{Collections::Other(), 1.0}};
1629 }
1630 return results;
1631 }
1632
1633 bool Annotator::ModelAnnotate(
1634 const std::string& context,
1635 const std::vector<Locale>& detected_text_language_tags,
1636 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1637 std::vector<AnnotatedSpan>* result) const {
1638 if (model_->triggering_options() == nullptr ||
1639 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1640 return true;
1641 }
1642
1643 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1644 ml_model_triggering_locales_,
1645 /*default_value=*/true)) {
1646 return true;
1647 }
1648
1649 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1650 /*do_copy=*/false);
1651 std::vector<UnicodeTextRange> lines;
1652 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1653 lines.push_back({context_unicode.begin(), context_unicode.end()});
1654 } else {
1655 lines = selection_feature_processor_->SplitContext(context_unicode);
1656 }
1657
1658 const float min_annotate_confidence =
1659 (model_->triggering_options() != nullptr
1660 ? model_->triggering_options()->min_annotate_confidence()
1661 : 0.f);
1662
1663 for (const UnicodeTextRange& line : lines) {
1664 FeatureProcessor::EmbeddingCache embedding_cache;
1665 const std::string line_str =
1666 UnicodeText::UTF8Substring(line.first, line.second);
1667
1668 *tokens = selection_feature_processor_->Tokenize(line_str);
1669 selection_feature_processor_->RetokenizeAndFindClick(
1670 line_str, {0, std::distance(line.first, line.second)},
1671 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1672 tokens,
1673 /*click_pos=*/nullptr);
1674 const TokenSpan full_line_span = {0, tokens->size()};
1675
1676 // TODO(zilka): Add support for greater granularity of this check.
1677 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1678 *tokens, full_line_span)) {
1679 continue;
1680 }
1681
1682 std::unique_ptr<CachedFeatures> cached_features;
1683 if (!selection_feature_processor_->ExtractFeatures(
1684 *tokens, full_line_span,
1685 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1686 embedding_executor_.get(),
1687 /*embedding_cache=*/nullptr,
1688 selection_feature_processor_->EmbeddingSize() +
1689 selection_feature_processor_->DenseFeaturesCount(),
1690 &cached_features)) {
1691 TC3_LOG(ERROR) << "Could not extract features.";
1692 return false;
1693 }
1694
1695 std::vector<TokenSpan> local_chunks;
1696 if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
1697 interpreter_manager->SelectionInterpreter(),
1698 *cached_features, &local_chunks)) {
1699 TC3_LOG(ERROR) << "Could not chunk.";
1700 return false;
1701 }
1702
1703 const int offset = std::distance(context_unicode.begin(), line.first);
1704 for (const TokenSpan& chunk : local_chunks) {
1705 const CodepointSpan codepoint_span =
1706 selection_feature_processor_->StripBoundaryCodepoints(
1707 line_str, TokenSpanToCodepointSpan(*tokens, chunk));
1708
1709 // Skip empty spans.
1710 if (codepoint_span.first != codepoint_span.second) {
1711 std::vector<ClassificationResult> classification;
1712 if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
1713 codepoint_span, interpreter_manager,
1714 &embedding_cache, &classification)) {
1715 TC3_LOG(ERROR) << "Could not classify text: "
1716 << (codepoint_span.first + offset) << " "
1717 << (codepoint_span.second + offset);
1718 return false;
1719 }
1720
1721 // Do not include the span if it's classified as "other".
1722 if (!classification.empty() && !ClassifiedAsOther(classification) &&
1723 classification[0].score >= min_annotate_confidence) {
1724 AnnotatedSpan result_span;
1725 result_span.span = {codepoint_span.first + offset,
1726 codepoint_span.second + offset};
1727 result_span.classification = std::move(classification);
1728 result->push_back(std::move(result_span));
1729 }
1730 }
1731 }
1732 }
1733 return true;
1734 }
1735
1736 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
1737 return selection_feature_processor_.get();
1738 }
1739
1740 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
1741 const {
1742 return classification_feature_processor_.get();
1743 }
1744
1745 const DatetimeParser* Annotator::DatetimeParserForTests() const {
1746 return datetime_parser_.get();
1747 }
1748
1749 void Annotator::RemoveNotEnabledEntityTypes(
1750 const EnabledEntityTypes& is_entity_type_enabled,
1751 std::vector<AnnotatedSpan>* annotated_spans) const {
1752 for (AnnotatedSpan& annotated_span : *annotated_spans) {
1753 std::vector<ClassificationResult>& classifications =
1754 annotated_span.classification;
1755 classifications.erase(
1756 std::remove_if(classifications.begin(), classifications.end(),
1757 [&is_entity_type_enabled](
1758 const ClassificationResult& classification_result) {
1759 return !is_entity_type_enabled(
1760 classification_result.collection);
1761 }),
1762 classifications.end());
1763 }
1764 annotated_spans->erase(
1765 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
1766 [](const AnnotatedSpan& annotated_span) {
1767 return annotated_span.classification.empty();
1768 }),
1769 annotated_spans->end());
1770 }
1771
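// A minimal caller-side sketch (hypothetical code, not part of this file;
// "annotator" is assumed to be an already initialized Annotator instance):
//
//   AnnotationOptions options;
//   options.locales = "en";
//   const std::vector<AnnotatedSpan> spans =
//       annotator->Annotate("Meet me at 350 Third Street tomorrow", options);
//
// Each returned AnnotatedSpan carries a codepoint span and its classification
// results.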
1772 std::vector<AnnotatedSpan> Annotator::Annotate(
1773 const std::string& context, const AnnotationOptions& options) const {
1774 std::vector<AnnotatedSpan> candidates;
1775
1776 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
1777 return {};
1778 }
1779
1780 const UnicodeText context_unicode =
1781 UTF8ToUnicodeText(context, /*do_copy=*/false);
1782 if (!context_unicode.is_valid()) {
1783 return {};
1784 }
1785
1786 std::vector<Locale> detected_text_language_tags;
1787 if (!ParseLocales(options.detected_text_language_tags,
1788 &detected_text_language_tags)) {
1789 TC3_LOG(WARNING)
1790 << "Failed to parse the detected_text_language_tags in options: "
1791 << options.detected_text_language_tags;
1792 }
1793 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1794 model_triggering_locales_,
1795 /*default_value=*/true)) {
1796 return {};
1797 }
1798
1799 InterpreterManager interpreter_manager(selection_executor_.get(),
1800 classification_executor_.get());
1801
1802 // Annotate with the selection model.
1803 std::vector<Token> tokens;
1804 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
1805 &tokens, &candidates)) {
1806 TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
1807 return {};
1808 }
1809
1810 // Annotate with the regular expression models.
1811 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1812 annotation_regex_patterns_, &candidates,
1813 options.is_serialized_entity_data_enabled)) {
1814 TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
1815 return {};
1816 }
1817
1818 // Annotate with the datetime model.
1819 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
1820 if ((is_entity_type_enabled(Collections::Date()) ||
1821 is_entity_type_enabled(Collections::DateTime())) &&
1822 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1823 options.reference_time_ms_utc, options.reference_timezone,
1824 options.locales, ModeFlag_ANNOTATION,
1825 options.annotation_usecase,
1826 options.is_serialized_entity_data_enabled, &candidates)) {
1827     TC3_LOG(ERROR) << "Couldn't run DatetimeChunk.";
1828 return {};
1829 }
1830
1831 // Annotate with the knowledge engine.
1832 if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
1833 TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
1834 return {};
1835 }
1836
1837 // Annotate with the contact engine.
1838 if (contact_engine_ &&
1839 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
1840 TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
1841 return {};
1842 }
1843
1844 // Annotate with the installed app engine.
1845 if (installed_app_engine_ &&
1846 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
1847 TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
1848 return {};
1849 }
1850
1851 // Annotate with the number annotator.
1852 if (number_annotator_ != nullptr &&
1853 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
1854 &candidates)) {
1855 TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
1856 return {};
1857 }
1858
1859 // Annotate with the duration annotator.
1860 if (is_entity_type_enabled(Collections::Duration()) &&
1861 duration_annotator_ != nullptr &&
1862 !duration_annotator_->FindAll(context_unicode, tokens,
1863 options.annotation_usecase, &candidates)) {
1864 TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
1865 return {};
1866 }
1867
1868 // Sort candidates according to their position in the input, so that the next
1869 // code can assume that any connected component of overlapping spans forms a
1870 // contiguous block.
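  // E.g. the spans [0, 5), [3, 8) and [7, 12) form one connected chain of
  // overlaps and end up adjacent after sorting by start index, so they get
  // handled as one block.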
1871 std::sort(candidates.begin(), candidates.end(),
1872 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1873 return a.span.first < b.span.first;
1874 });
1875
1876 std::vector<int> candidate_indices;
1877 if (!ResolveConflicts(candidates, context, tokens,
1878 detected_text_language_tags, options.annotation_usecase,
1879 &interpreter_manager, &candidate_indices)) {
1880 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1881 return {};
1882 }
1883
1884 std::vector<AnnotatedSpan> result;
1885 result.reserve(candidate_indices.size());
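  // Candidates that share the same span are merged into a single
  // AnnotatedSpan; candidates with no classification, classified only as
  // "other", or filtered for annotation contribute nothing.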
1886 AnnotatedSpan aggregated_span;
1887 for (const int i : candidate_indices) {
1888 if (candidates[i].span != aggregated_span.span) {
1889 if (!aggregated_span.classification.empty()) {
1890 result.push_back(std::move(aggregated_span));
1891 }
1892 aggregated_span =
1893 AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
1894 }
1895 if (candidates[i].classification.empty() ||
1896 ClassifiedAsOther(candidates[i].classification) ||
1897 FilteredForAnnotation(candidates[i])) {
1898 continue;
1899 }
1900 for (ClassificationResult& classification : candidates[i].classification) {
1901 aggregated_span.classification.push_back(std::move(classification));
1902 }
1903 }
1904 if (!aggregated_span.classification.empty()) {
1905 result.push_back(std::move(aggregated_span));
1906 }
1907
1908   // We generate candidates for all entity types and only remove the disabled
1909   // ones here at the end (with the exception of date/time/duration entities,
1910   // which are gated up front), because there are complex interdependencies
1911   // between the entity types. E.g., the TLD of an email can be interpreted as
1912   // a URL, but most likely a user of the API does not want such annotations if "url" is enabled and "email" is not.
1913 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
1914
1915 for (AnnotatedSpan& annotated_span : result) {
1916 SortClassificationResults(&annotated_span.classification);
1917 }
1918
1919 return result;
1920 }
1921
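// Derives the selection span for a regex match: if the pattern config does not
// describe its capturing groups, the span of regex group 1 is used; otherwise
// the result is the union of the spans of all groups marked extend_selection.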
1922 CodepointSpan Annotator::ComputeSelectionBoundaries(
1923 const UniLib::RegexMatcher* match,
1924 const RegexModel_::Pattern* config) const {
1925 if (config->capturing_group() == nullptr) {
1926 // Use first capturing group to specify the selection.
1927 int status = UniLib::RegexMatcher::kNoError;
1928 const CodepointSpan result = {match->Start(1, &status),
1929 match->End(1, &status)};
1930 if (status != UniLib::RegexMatcher::kNoError) {
1931 return {kInvalidIndex, kInvalidIndex};
1932 }
1933 return result;
1934 }
1935
1936 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
1937 const int num_groups = config->capturing_group()->size();
1938 for (int i = 0; i < num_groups; i++) {
1939 if (!config->capturing_group()->Get(i)->extend_selection()) {
1940 continue;
1941 }
1942
1943 int status = UniLib::RegexMatcher::kNoError;
1944 // Check match and adjust bounds.
1945 const int group_start = match->Start(i, &status);
1946 const int group_end = match->End(i, &status);
1947 if (status != UniLib::RegexMatcher::kNoError) {
1948 return {kInvalidIndex, kInvalidIndex};
1949 }
1950 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
1951 continue;
1952 }
1953 if (result.first == kInvalidIndex) {
1954 result = {group_start, group_end};
1955 } else {
1956 result.first = std::min(result.first, group_start);
1957 result.second = std::max(result.second, group_end);
1958 }
1959 }
1960 return result;
1961 }
1962
1963 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
1964 if (pattern->serialized_entity_data() != nullptr) {
1965 return true;
1966 }
1967 if (pattern->capturing_group() != nullptr) {
1968 for (const RegexModel_::Pattern_::CapturingGroup* group :
1969 *pattern->capturing_group()) {
1970 if (group->entity_field_path() != nullptr) {
1971 return true;
1972 }
1973 }
1974 }
1975 return false;
1976 }
1977
1978 bool Annotator::SerializedEntityDataFromRegexMatch(
1979 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
1980 std::string* serialized_entity_data) const {
1981 if (!HasEntityData(pattern)) {
1982 serialized_entity_data->clear();
1983 return true;
1984 }
1985 TC3_CHECK(entity_data_builder_ != nullptr);
1986
1987 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
1988 entity_data_builder_->NewRoot();
1989
1990 TC3_CHECK(entity_data != nullptr);
1991
1992 // Set static entity data.
1993 if (pattern->serialized_entity_data() != nullptr) {
1994 TC3_CHECK(entity_data != nullptr);
1995 entity_data->MergeFromSerializedFlatbuffer(
1996 StringPiece(pattern->serialized_entity_data()->c_str(),
1997 pattern->serialized_entity_data()->size()));
1998 }
1999
2000 // Add entity data from rule capturing groups.
2001 if (pattern->capturing_group() != nullptr) {
2002 const int num_groups = pattern->capturing_group()->size();
2003 for (int i = 0; i < num_groups; i++) {
2004 const FlatbufferFieldPath* field_path =
2005 pattern->capturing_group()->Get(i)->entity_field_path();
2006 if (field_path == nullptr) {
2007 continue;
2008 }
2009 TC3_CHECK(entity_data != nullptr);
2010 if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
2011 entity_data.get())) {
2012 TC3_LOG(ERROR)
2013 << "Could not set entity data from rule capturing group.";
2014 return false;
2015 }
2016 }
2017 }
2018
2019 *serialized_entity_data = entity_data->Serialize();
2020 return true;
2021 }
2022
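// Runs the given regex rules over the whole context and appends one
// AnnotatedSpan per accepted match; serialized entity data is attached when
// enabled.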
2023 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2024 const std::vector<int>& rules,
2025 std::vector<AnnotatedSpan>* result,
2026 bool is_serialized_entity_data_enabled) const {
2027 for (int pattern_id : rules) {
2028 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2029 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2030 if (!matcher) {
2031 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2032 << pattern_id;
2033 return false;
2034 }
2035
2036 int status = UniLib::RegexMatcher::kNoError;
2037 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2038 if (regex_pattern.config->verification_options()) {
2039 if (!VerifyRegexMatchCandidate(
2040 context_unicode.ToUTF8String(),
2041 regex_pattern.config->verification_options(),
2042 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2043 continue;
2044 }
2045 }
2046
2047 std::string serialized_entity_data;
2048 if (is_serialized_entity_data_enabled) {
2049 if (!SerializedEntityDataFromRegexMatch(
2050 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2051 TC3_LOG(ERROR) << "Could not get entity data.";
2052 return false;
2053 }
2054 }
2055
2056 result->emplace_back();
2057
2058       // Selection/annotation regular expressions need to provide a capturing
2059       // group that specifies the selection.
2060 result->back().span =
2061 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2062
2063 result->back().classification = {
2064 {regex_pattern.config->collection_name()->str(),
2065 regex_pattern.config->target_classification_score(),
2066 regex_pattern.config->priority_score()}};
2067
2068 result->back().classification[0].serialized_entity_data =
2069 serialized_entity_data;
2070 }
2071 }
2072 return true;
2073 }
2074
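// Chunks the tokens with the selection model: candidate token spans are scored
// with either the bounds-sensitive or the click-context variant (depending on
// the model options) and then picked greedily in order of decreasing score so
// that picked chunks do not overlap.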
2075 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2076 tflite::Interpreter* selection_interpreter,
2077 const CachedFeatures& cached_features,
2078 std::vector<TokenSpan>* chunks) const {
2079 const int max_selection_span =
2080 selection_feature_processor_->GetOptions()->max_selection_span();
2081 // The inference span is the span of interest expanded to include
2082 // max_selection_span tokens on either side, which is how far a selection can
2083 // stretch from the click.
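  // E.g. with span_of_interest = {3, 5}, max_selection_span = 2 and 20 tokens
  // in total, the inference span is {1, 7}.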
2084 const TokenSpan inference_span = IntersectTokenSpans(
2085 ExpandTokenSpan(span_of_interest,
2086 /*num_tokens_left=*/max_selection_span,
2087 /*num_tokens_right=*/max_selection_span),
2088 {0, num_tokens});
2089
2090 std::vector<ScoredChunk> scored_chunks;
2091 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2092 selection_feature_processor_->GetOptions()
2093 ->bounds_sensitive_features()
2094 ->enabled()) {
2095 if (!ModelBoundsSensitiveScoreChunks(
2096 num_tokens, span_of_interest, inference_span, cached_features,
2097 selection_interpreter, &scored_chunks)) {
2098 return false;
2099 }
2100 } else {
2101 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2102 cached_features, selection_interpreter,
2103 &scored_chunks)) {
2104 return false;
2105 }
2106 }
2107 std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2108 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2109 return lhs.score < rhs.score;
2110 });
2111
2112 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2113 // them greedily as long as they do not overlap with any previously picked
2114 // chunks.
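  // E.g. if {0, 2} scores 0.9, {1, 3} scores 0.8 and {2, 4} scores 0.7, the
  // pass picks {0, 2}, skips {1, 3} (token 1 is already used) and then picks
  // {2, 4}.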
2115 std::vector<bool> token_used(TokenSpanSize(inference_span));
2116 chunks->clear();
2117 for (const ScoredChunk& scored_chunk : scored_chunks) {
2118 bool feasible = true;
2119 for (int i = scored_chunk.token_span.first;
2120 i < scored_chunk.token_span.second; ++i) {
2121 if (token_used[i - inference_span.first]) {
2122 feasible = false;
2123 break;
2124 }
2125 }
2126
2127 if (!feasible) {
2128 continue;
2129 }
2130
2131 for (int i = scored_chunk.token_span.first;
2132 i < scored_chunk.token_span.second; ++i) {
2133 token_used[i - inference_span.first] = true;
2134 }
2135
2136 chunks->push_back(scored_chunk.token_span);
2137 }
2138
2139 std::sort(chunks->begin(), chunks->end());
2140
2141 return true;
2142 }
2143
2144 namespace {
2145 // Updates the value at the given key in the map to the maximum of the current
2146 // value and the given value, or simply inserts the value if the key is absent.
2147 template <typename Map>
2148 void UpdateMax(Map* map, typename Map::key_type key,
2149 typename Map::mapped_type value) {
2150 const auto it = map->find(key);
2151 if (it != map->end()) {
2152 it->second = std::max(it->second, value);
2153 } else {
2154 (*map)[key] = value;
2155 }
2156 }
2157 } // namespace
2158
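// Scores click-context candidate chunks: for every click position in the span
// of interest the selection model produces one score per selection label; each
// label corresponds to a token span relative to the click, and the best score
// seen for each resulting span is kept (UpdateMax).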
2159 bool Annotator::ModelClickContextScoreChunks(
2160 int num_tokens, const TokenSpan& span_of_interest,
2161 const CachedFeatures& cached_features,
2162 tflite::Interpreter* selection_interpreter,
2163 std::vector<ScoredChunk>* scored_chunks) const {
2164 const int max_batch_size = model_->selection_options()->batch_size();
2165
2166 std::vector<float> all_features;
2167 std::map<TokenSpan, float> chunk_scores;
2168 for (int batch_start = span_of_interest.first;
2169 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2170 const int batch_end =
2171 std::min(batch_start + max_batch_size, span_of_interest.second);
2172
2173 // Prepare features for the whole batch.
2174 all_features.clear();
2175 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2176 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2177 cached_features.AppendClickContextFeaturesForClick(click_pos,
2178 &all_features);
2179 }
2180
2181 // Run batched inference.
2182 const int batch_size = batch_end - batch_start;
2183 const int features_size = cached_features.OutputFeaturesSize();
2184 TensorView<float> logits = selection_executor_->ComputeLogits(
2185 TensorView<float>(all_features.data(), {batch_size, features_size}),
2186 selection_interpreter);
2187 if (!logits.is_valid()) {
2188 TC3_LOG(ERROR) << "Couldn't compute logits.";
2189 return false;
2190 }
2191 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2192 logits.dim(1) !=
2193 selection_feature_processor_->GetSelectionLabelCount()) {
2194 TC3_LOG(ERROR) << "Mismatching output.";
2195 return false;
2196 }
2197
2198 // Save results.
2199 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2200 const std::vector<float> scores = ComputeSoftmax(
2201 logits.data() + logits.dim(1) * (click_pos - batch_start),
2202 logits.dim(1));
2203 for (int j = 0;
2204 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2205 TokenSpan relative_token_span;
2206 if (!selection_feature_processor_->LabelToTokenSpan(
2207 j, &relative_token_span)) {
2208 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
2209 return false;
2210 }
2211 const TokenSpan candidate_span = ExpandTokenSpan(
2212 SingleTokenSpan(click_pos), relative_token_span.first,
2213 relative_token_span.second);
2214 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2215 UpdateMax(&chunk_scores, candidate_span, scores[j]);
2216 }
2217 }
2218 }
2219 }
2220
2221 scored_chunks->clear();
2222 scored_chunks->reserve(chunk_scores.size());
2223 for (const auto& entry : chunk_scores) {
2224 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2225 }
2226
2227 return true;
2228 }
2229
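// Scores bounds-sensitive candidate chunks: every candidate span is featurized
// with its exact boundaries and scored with a single model output; optionally,
// single-token spans are given a zero score without running the model.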
2230 bool Annotator::ModelBoundsSensitiveScoreChunks(
2231 int num_tokens, const TokenSpan& span_of_interest,
2232 const TokenSpan& inference_span, const CachedFeatures& cached_features,
2233 tflite::Interpreter* selection_interpreter,
2234 std::vector<ScoredChunk>* scored_chunks) const {
2235 const int max_selection_span =
2236 selection_feature_processor_->GetOptions()->max_selection_span();
2237 const int max_chunk_length = selection_feature_processor_->GetOptions()
2238 ->selection_reduced_output_space()
2239 ? max_selection_span + 1
2240 : 2 * max_selection_span + 1;
2241 const bool score_single_token_spans_as_zero =
2242 selection_feature_processor_->GetOptions()
2243 ->bounds_sensitive_features()
2244 ->score_single_token_spans_as_zero();
2245
2246 scored_chunks->clear();
2247 if (score_single_token_spans_as_zero) {
2248 scored_chunks->reserve(TokenSpanSize(span_of_interest));
2249 }
2250
2251 // Prepare all chunk candidates into one batch:
2252 // - Are contained in the inference span
2253 // - Have a non-empty intersection with the span of interest
2254 // - Are at least one token long
2255 // - Are not longer than the maximum chunk length
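  // E.g. with inference_span = {0, 6}, span_of_interest = {2, 4} and
  // max_chunk_length = 3, the span {1, 4} is a candidate while {4, 6} is not
  // (it does not intersect the span of interest).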
2256 std::vector<TokenSpan> candidate_spans;
2257 for (int start = inference_span.first; start < span_of_interest.second;
2258 ++start) {
2259 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2260 for (int end = leftmost_end_index;
2261 end <= inference_span.second && end - start <= max_chunk_length;
2262 ++end) {
2263 const TokenSpan candidate_span = {start, end};
2264 if (score_single_token_spans_as_zero &&
2265 TokenSpanSize(candidate_span) == 1) {
2266 // Do not include the single token span in the batch, add a zero score
2267 // for it directly to the output.
2268 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2269 } else {
2270 candidate_spans.push_back(candidate_span);
2271 }
2272 }
2273 }
2274
2275 const int max_batch_size = model_->selection_options()->batch_size();
2276
2277 std::vector<float> all_features;
2278 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
2279 for (int batch_start = 0; batch_start < candidate_spans.size();
2280 batch_start += max_batch_size) {
2281 const int batch_end = std::min(batch_start + max_batch_size,
2282 static_cast<int>(candidate_spans.size()));
2283
2284 // Prepare features for the whole batch.
2285 all_features.clear();
2286 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2287 for (int i = batch_start; i < batch_end; ++i) {
2288 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2289 &all_features);
2290 }
2291
2292 // Run batched inference.
2293 const int batch_size = batch_end - batch_start;
2294 const int features_size = cached_features.OutputFeaturesSize();
2295 TensorView<float> logits = selection_executor_->ComputeLogits(
2296 TensorView<float>(all_features.data(), {batch_size, features_size}),
2297 selection_interpreter);
2298 if (!logits.is_valid()) {
2299 TC3_LOG(ERROR) << "Couldn't compute logits.";
2300 return false;
2301 }
2302 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2303 logits.dim(1) != 1) {
2304 TC3_LOG(ERROR) << "Mismatching output.";
2305 return false;
2306 }
2307
2308 // Save results.
2309 for (int i = batch_start; i < batch_end; ++i) {
2310 scored_chunks->push_back(
2311 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2312 }
2313 }
2314
2315 return true;
2316 }
2317
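// Runs the datetime parser over the whole context and appends one
// AnnotatedSpan per parsed span, with one classification result per datetime
// interpretation.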
2318 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2319 int64 reference_time_ms_utc,
2320 const std::string& reference_timezone,
2321 const std::string& locales, ModeFlag mode,
2322 AnnotationUsecase annotation_usecase,
2323 bool is_serialized_entity_data_enabled,
2324 std::vector<AnnotatedSpan>* result) const {
2325 if (!datetime_parser_) {
2326 return true;
2327 }
2328
2329 std::vector<DatetimeParseResultSpan> datetime_spans;
2330 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2331 reference_timezone, locales, mode,
2332 annotation_usecase,
2333 /*anchor_start_end=*/false, &datetime_spans)) {
2334 return false;
2335 }
2336 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
2337 AnnotatedSpan annotated_span;
2338 annotated_span.span = datetime_span.span;
2339 for (const DatetimeParseResult& parse_result : datetime_span.data) {
2340 annotated_span.classification.emplace_back(
2341 PickCollectionForDatetime(parse_result),
2342 datetime_span.target_classification_score,
2343 datetime_span.priority_score);
2344 annotated_span.classification.back().datetime_parse_result = parse_result;
2345 if (is_serialized_entity_data_enabled) {
2346 annotated_span.classification.back().serialized_entity_data =
2347 CreateDatetimeSerializedEntityData(parse_result);
2348 }
2349 }
2350 annotated_span.source = AnnotatedSpan::Source::DATETIME;
2351 result->push_back(std::move(annotated_span));
2352 }
2353 return true;
2354 }
2355
2356 const Model* Annotator::model() const { return model_; }
2357 const reflection::Schema* Annotator::entity_data_schema() const {
2358 return entity_data_schema_;
2359 }
2360
2361 const Model* ViewModel(const void* buffer, int size) {
2362 if (!buffer) {
2363 return nullptr;
2364 }
2365
2366 return LoadAndVerifyModel(buffer, size);
2367 }
2368
2369 bool Annotator::LookUpKnowledgeEntity(
2370 const std::string& id, std::string* serialized_knowledge_result) const {
2371 return knowledge_engine_ &&
2372 knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2373 }
2374
2375 } // namespace libtextclassifier3
2376