1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/actions-suggestions.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "utils/base/statusor.h"
23 
24 #if !defined(TC3_DISABLE_LUA)
25 #include "actions/lua-actions.h"
26 #endif
27 #include "actions/ngram-model.h"
28 #include "actions/tflite-sensitive-model.h"
29 #include "actions/types.h"
30 #include "actions/utils.h"
31 #include "actions/zlib-utils.h"
32 #include "annotator/collections.h"
33 #include "utils/base/logging.h"
34 #if !defined(TC3_DISABLE_LUA)
35 #include "utils/lua-utils.h"
36 #endif
37 #include "utils/normalization.h"
38 #include "utils/optional.h"
39 #include "utils/strings/split.h"
40 #include "utils/strings/stringpiece.h"
41 #include "utils/strings/utf8.h"
42 #include "utils/utf8/unicodetext.h"
43 #include "tensorflow/lite/string_util.h"
44 
45 namespace libtextclassifier3 {
46 
// Default scalar values. Their use sites are outside this section of the file.
// NOTE(review): presumably fallbacks for model inputs that the caller did not
// supply — confirm against the users further down.
constexpr float kDefaultFloat = 0.0f;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
50 
51 namespace {
52 
LoadAndVerifyModel(const uint8_t * addr,int size)53 const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
54   flatbuffers::Verifier verifier(addr, size);
55   if (VerifyActionsModelBuffer(verifier)) {
56     return GetActionsModel(addr);
57   } else {
58     return nullptr;
59   }
60 }
61 
62 template <typename T>
ValueOrDefault(const flatbuffers::Table * values,const int32 field_offset,const T default_value)63 T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
64                  const T default_value) {
65   if (values == nullptr) {
66     return default_value;
67   }
68   return values->GetField<T>(field_offset, default_value);
69 }
70 
71 // Returns number of (tail) messages of a conversation to consider.
NumMessagesToConsider(const Conversation & conversation,const int max_conversation_history_length)72 int NumMessagesToConsider(const Conversation& conversation,
73                           const int max_conversation_history_length) {
74   return ((max_conversation_history_length < 0 ||
75            conversation.messages.size() < max_conversation_history_length)
76               ? conversation.messages.size()
77               : max_conversation_history_length);
78 }
79 
80 template <typename T>
PadOrTruncateToTargetLength(const std::vector<T> & inputs,const int max_length,const T pad_value)81 std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
82                                            const int max_length,
83                                            const T pad_value) {
84   if (inputs.size() >= max_length) {
85     return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
86   } else {
87     std::vector<T> result;
88     result.reserve(max_length);
89     result.insert(result.begin(), inputs.begin(), inputs.end());
90     result.insert(result.end(), max_length - inputs.size(), pad_value);
91     return result;
92   }
93 }
94 
95 template <typename T>
SetVectorOrScalarAsModelInput(const int param_index,const Variant & param_value,tflite::Interpreter * interpreter,const std::unique_ptr<const TfLiteModelExecutor> & model_executor)96 void SetVectorOrScalarAsModelInput(
97     const int param_index, const Variant& param_value,
98     tflite::Interpreter* interpreter,
99     const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
100   if (param_value.Has<std::vector<T>>()) {
101     model_executor->SetInput<T>(
102         param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
103   } else if (param_value.Has<T>()) {
104     model_executor->SetInput<float>(param_index, param_value.Value<T>(),
105                                     interpreter);
106   } else {
107     TC3_LOG(ERROR) << "Variant type error!";
108   }
109 }
110 }  // namespace
111 
FromUnownedBuffer(const uint8_t * buffer,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)112 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
113     const uint8_t* buffer, const int size, const UniLib* unilib,
114     const std::string& triggering_preconditions_overlay) {
115   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
116   const ActionsModel* model = LoadAndVerifyModel(buffer, size);
117   if (model == nullptr) {
118     return nullptr;
119   }
120   actions->model_ = model;
121   actions->SetOrCreateUnilib(unilib);
122   actions->triggering_preconditions_overlay_buffer_ =
123       triggering_preconditions_overlay;
124   if (!actions->ValidateAndInitialize()) {
125     return nullptr;
126   }
127   return actions;
128 }
129 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,const UniLib * unilib,const std::string & triggering_preconditions_overlay)130 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
131     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
132     const std::string& triggering_preconditions_overlay) {
133   if (!mmap->handle().ok()) {
134     TC3_VLOG(1) << "Mmap failed.";
135     return nullptr;
136   }
137   const ActionsModel* model = LoadAndVerifyModel(
138       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
139       mmap->handle().num_bytes());
140   if (!model) {
141     TC3_LOG(ERROR) << "Model verification failed.";
142     return nullptr;
143   }
144   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
145   actions->model_ = model;
146   actions->mmap_ = std::move(mmap);
147   actions->SetOrCreateUnilib(unilib);
148   actions->triggering_preconditions_overlay_buffer_ =
149       triggering_preconditions_overlay;
150   if (!actions->ValidateAndInitialize()) {
151     return nullptr;
152   }
153   return actions;
154 }
155 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)156 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
157     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
158     std::unique_ptr<UniLib> unilib,
159     const std::string& triggering_preconditions_overlay) {
160   if (!mmap->handle().ok()) {
161     TC3_VLOG(1) << "Mmap failed.";
162     return nullptr;
163   }
164   const ActionsModel* model = LoadAndVerifyModel(
165       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
166       mmap->handle().num_bytes());
167   if (!model) {
168     TC3_LOG(ERROR) << "Model verification failed.";
169     return nullptr;
170   }
171   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
172   actions->model_ = model;
173   actions->mmap_ = std::move(mmap);
174   actions->owned_unilib_ = std::move(unilib);
175   actions->unilib_ = actions->owned_unilib_.get();
176   actions->triggering_preconditions_overlay_buffer_ =
177       triggering_preconditions_overlay;
178   if (!actions->ValidateAndInitialize()) {
179     return nullptr;
180   }
181   return actions;
182 }
183 
FromFileDescriptor(const int fd,const int offset,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)184 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
185     const int fd, const int offset, const int size, const UniLib* unilib,
186     const std::string& triggering_preconditions_overlay) {
187   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
188   if (offset >= 0 && size >= 0) {
189     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
190   } else {
191     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
192   }
193   return FromScopedMmap(std::move(mmap), unilib,
194                         triggering_preconditions_overlay);
195 }
196 
FromFileDescriptor(const int fd,const int offset,const int size,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)197 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
198     const int fd, const int offset, const int size,
199     std::unique_ptr<UniLib> unilib,
200     const std::string& triggering_preconditions_overlay) {
201   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
202   if (offset >= 0 && size >= 0) {
203     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
204   } else {
205     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
206   }
207   return FromScopedMmap(std::move(mmap), std::move(unilib),
208                         triggering_preconditions_overlay);
209 }
210 
FromFileDescriptor(const int fd,const UniLib * unilib,const std::string & triggering_preconditions_overlay)211 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
212     const int fd, const UniLib* unilib,
213     const std::string& triggering_preconditions_overlay) {
214   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
215       new libtextclassifier3::ScopedMmap(fd));
216   return FromScopedMmap(std::move(mmap), unilib,
217                         triggering_preconditions_overlay);
218 }
219 
FromFileDescriptor(const int fd,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)220 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
221     const int fd, std::unique_ptr<UniLib> unilib,
222     const std::string& triggering_preconditions_overlay) {
223   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
224       new libtextclassifier3::ScopedMmap(fd));
225   return FromScopedMmap(std::move(mmap), std::move(unilib),
226                         triggering_preconditions_overlay);
227 }
228 
FromPath(const std::string & path,const UniLib * unilib,const std::string & triggering_preconditions_overlay)229 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
230     const std::string& path, const UniLib* unilib,
231     const std::string& triggering_preconditions_overlay) {
232   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
233       new libtextclassifier3::ScopedMmap(path));
234   return FromScopedMmap(std::move(mmap), unilib,
235                         triggering_preconditions_overlay);
236 }
237 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)238 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
239     const std::string& path, std::unique_ptr<UniLib> unilib,
240     const std::string& triggering_preconditions_overlay) {
241   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
242       new libtextclassifier3::ScopedMmap(path));
243   return FromScopedMmap(std::move(mmap), std::move(unilib),
244                         triggering_preconditions_overlay);
245 }
246 
SetOrCreateUnilib(const UniLib * unilib)247 void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
248   if (unilib != nullptr) {
249     unilib_ = unilib;
250   } else {
251     owned_unilib_.reset(new UniLib);
252     unilib_ = owned_unilib_.get();
253   }
254 }
255 
ValidateAndInitialize()256 bool ActionsSuggestions::ValidateAndInitialize() {
257   if (model_ == nullptr) {
258     TC3_LOG(ERROR) << "No model specified.";
259     return false;
260   }
261 
262   if (model_->smart_reply_action_type() == nullptr) {
263     TC3_LOG(ERROR) << "No smart reply action type specified.";
264     return false;
265   }
266 
267   if (!InitializeTriggeringPreconditions()) {
268     TC3_LOG(ERROR) << "Could not initialize preconditions.";
269     return false;
270   }
271 
272   if (model_->locales() &&
273       !ParseLocales(model_->locales()->c_str(), &locales_)) {
274     TC3_LOG(ERROR) << "Could not parse model supported locales.";
275     return false;
276   }
277 
278   if (model_->tflite_model_spec() != nullptr) {
279     model_executor_ = TfLiteModelExecutor::FromBuffer(
280         model_->tflite_model_spec()->tflite_model());
281     if (!model_executor_) {
282       TC3_LOG(ERROR) << "Could not initialize model executor.";
283       return false;
284     }
285   }
286 
287   // Gather annotation entities for the rules.
288   if (model_->annotation_actions_spec() != nullptr &&
289       model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
290     for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
291          *model_->annotation_actions_spec()->annotation_mapping()) {
292       annotation_entity_types_.insert(mapping->annotation_collection()->str());
293     }
294   }
295 
296   if (model_->actions_entity_data_schema() != nullptr) {
297     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
298         model_->actions_entity_data_schema()->Data(),
299         model_->actions_entity_data_schema()->size());
300     if (entity_data_schema_ == nullptr) {
301       TC3_LOG(ERROR) << "Could not load entity data schema data.";
302       return false;
303     }
304 
305     entity_data_builder_.reset(
306         new MutableFlatbufferBuilder(entity_data_schema_));
307   } else {
308     entity_data_schema_ = nullptr;
309   }
310 
311   // Initialize regular expressions model.
312   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
313   regex_actions_.reset(
314       new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
315   if (!regex_actions_->InitializeRules(
316           model_->rules(), model_->low_confidence_rules(),
317           triggering_preconditions_overlay_, decompressor.get())) {
318     TC3_LOG(ERROR) << "Could not initialize regex rules.";
319     return false;
320   }
321 
322   // Setup grammar model.
323   if (model_->rules() != nullptr &&
324       model_->rules()->grammar_rules() != nullptr) {
325     grammar_actions_.reset(new GrammarActions(
326         unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
327         model_->smart_reply_action_type()->str()));
328 
329     // Gather annotation entities for the grammars.
330     if (auto annotation_nt = model_->rules()
331                                  ->grammar_rules()
332                                  ->rules()
333                                  ->nonterminals()
334                                  ->annotation_nt()) {
335       for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
336            *annotation_nt) {
337         annotation_entity_types_.insert(entry->key()->str());
338       }
339     }
340   }
341 
342 #if !defined(TC3_DISABLE_LUA)
343   std::string actions_script;
344   if (GetUncompressedString(model_->lua_actions_script(),
345                             model_->compressed_lua_actions_script(),
346                             decompressor.get(), &actions_script) &&
347       !actions_script.empty()) {
348     if (!Compile(actions_script, &lua_bytecode_)) {
349       TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
350       return false;
351     }
352   }
353 #endif  // TC3_DISABLE_LUA
354 
355   if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
356             model_->ranking_options(), decompressor.get(),
357             model_->smart_reply_action_type()->str()))) {
358     TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
359     return false;
360   }
361 
362   // Create feature processor if specified.
363   const ActionsTokenFeatureProcessorOptions* options =
364       model_->feature_processor_options();
365   if (options != nullptr) {
366     if (options->tokenizer_options() == nullptr) {
367       TC3_LOG(ERROR) << "No tokenizer options specified.";
368       return false;
369     }
370 
371     feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
372     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
373         options->embedding_model(), options->embedding_size(),
374         options->embedding_quantization_bits());
375 
376     if (embedding_executor_ == nullptr) {
377       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
378       return false;
379     }
380 
381     // Cache embedding of padding, start and end token.
382     if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
383         !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
384         !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
385       TC3_LOG(ERROR) << "Could not precompute token embeddings.";
386       return false;
387     }
388     token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
389   }
390 
391   // Create low confidence model if specified.
392   if (model_->low_confidence_ngram_model() != nullptr) {
393     sensitive_model_ = NGramSensitiveModel::Create(
394         unilib_, model_->low_confidence_ngram_model(),
395         feature_processor_ == nullptr ? nullptr
396                                       : feature_processor_->tokenizer());
397     if (sensitive_model_ == nullptr) {
398       TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
399       return false;
400     }
401   }
402   if (model_->low_confidence_tflite_model() != nullptr) {
403     sensitive_model_ =
404         TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
405     if (sensitive_model_ == nullptr) {
406       TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
407       return false;
408     }
409   }
410 
411   return true;
412 }
413 
InitializeTriggeringPreconditions()414 bool ActionsSuggestions::InitializeTriggeringPreconditions() {
415   triggering_preconditions_overlay_ =
416       LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
417           triggering_preconditions_overlay_buffer_);
418 
419   if (triggering_preconditions_overlay_ == nullptr &&
420       !triggering_preconditions_overlay_buffer_.empty()) {
421     TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
422     return false;
423   }
424   const flatbuffers::Table* overlay =
425       reinterpret_cast<const flatbuffers::Table*>(
426           triggering_preconditions_overlay_);
427   const TriggeringPreconditions* defaults = model_->preconditions();
428   if (defaults == nullptr) {
429     TC3_LOG(ERROR) << "No triggering conditions specified.";
430     return false;
431   }
432 
433   preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
434       overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
435       defaults->min_smart_reply_triggering_score());
436   preconditions_.max_sensitive_topic_score = ValueOrDefault(
437       overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
438       defaults->max_sensitive_topic_score());
439   preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
440       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
441       defaults->suppress_on_sensitive_topic());
442   preconditions_.min_input_length =
443       ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
444                      defaults->min_input_length());
445   preconditions_.max_input_length =
446       ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
447                      defaults->max_input_length());
448   preconditions_.min_locale_match_fraction = ValueOrDefault(
449       overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
450       defaults->min_locale_match_fraction());
451   preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
452       overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
453       defaults->handle_missing_locale_as_supported());
454   preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
455       overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
456       defaults->handle_unknown_locale_as_supported());
457   preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
458       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
459       defaults->suppress_on_low_confidence_input());
460   preconditions_.min_reply_score_threshold = ValueOrDefault(
461       overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
462       defaults->min_reply_score_threshold());
463 
464   return true;
465 }
466 
EmbedTokenId(const int32 token_id,std::vector<float> * embedding) const467 bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
468                                       std::vector<float>* embedding) const {
469   return feature_processor_->AppendFeatures(
470       {token_id},
471       /*dense_features=*/{}, embedding_executor_.get(), embedding);
472 }
473 
Tokenize(const std::vector<std::string> & context) const474 std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
475     const std::vector<std::string>& context) const {
476   std::vector<std::vector<Token>> tokens;
477   tokens.reserve(context.size());
478   for (const std::string& message : context) {
479     tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
480   }
481   return tokens;
482 }
483 
EmbedTokensPerMessage(const std::vector<std::vector<Token>> & tokens,std::vector<float> * embeddings,int * max_num_tokens_per_message) const484 bool ActionsSuggestions::EmbedTokensPerMessage(
485     const std::vector<std::vector<Token>>& tokens,
486     std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
487   const int num_messages = tokens.size();
488   *max_num_tokens_per_message = 0;
489   for (int i = 0; i < num_messages; i++) {
490     const int num_message_tokens = tokens[i].size();
491     if (num_message_tokens > *max_num_tokens_per_message) {
492       *max_num_tokens_per_message = num_message_tokens;
493     }
494   }
495 
496   if (model_->feature_processor_options()->min_num_tokens_per_message() >
497       *max_num_tokens_per_message) {
498     *max_num_tokens_per_message =
499         model_->feature_processor_options()->min_num_tokens_per_message();
500   }
501   if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
502       *max_num_tokens_per_message >
503           model_->feature_processor_options()->max_num_tokens_per_message()) {
504     *max_num_tokens_per_message =
505         model_->feature_processor_options()->max_num_tokens_per_message();
506   }
507 
508   // Embed all tokens and add paddings to pad tokens of each message to the
509   // maximum number of tokens in a message of the conversation.
510   // If a number of tokens is specified in the model config, tokens at the
511   // beginning of a message are dropped if they don't fit in the limit.
512   for (int i = 0; i < num_messages; i++) {
513     const int start =
514         std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
515     for (int pos = start; pos < tokens[i].size(); pos++) {
516       if (!feature_processor_->AppendTokenFeatures(
517               tokens[i][pos], embedding_executor_.get(), embeddings)) {
518         TC3_LOG(ERROR) << "Could not run token feature extractor.";
519         return false;
520       }
521     }
522     // Add padding.
523     for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
524       embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
525                          embedded_padding_token_.end());
526     }
527   }
528 
529   return true;
530 }
531 
EmbedAndFlattenTokens(const std::vector<std::vector<Token>> & tokens,std::vector<float> * embeddings,int * total_token_count) const532 bool ActionsSuggestions::EmbedAndFlattenTokens(
533     const std::vector<std::vector<Token>>& tokens,
534     std::vector<float>* embeddings, int* total_token_count) const {
535   const int num_messages = tokens.size();
536   int start_message = 0;
537   int message_token_offset = 0;
538 
539   // If a maximum model input length is specified, we need to check how
540   // much we need to trim at the start.
541   const int max_num_total_tokens =
542       model_->feature_processor_options()->max_num_total_tokens();
543   if (max_num_total_tokens > 0) {
544     int total_tokens = 0;
545     start_message = num_messages - 1;
546     for (; start_message >= 0; start_message--) {
547       // Tokens of the message + start and end token.
548       const int num_message_tokens = tokens[start_message].size() + 2;
549       total_tokens += num_message_tokens;
550 
551       // Check whether we exhausted the budget.
552       if (total_tokens >= max_num_total_tokens) {
553         message_token_offset = total_tokens - max_num_total_tokens;
554         break;
555       }
556     }
557   }
558 
559   // Add embeddings.
560   *total_token_count = 0;
561   for (int i = start_message; i < num_messages; i++) {
562     if (message_token_offset == 0) {
563       ++(*total_token_count);
564       // Add `start message` token.
565       embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
566                          embedded_start_token_.end());
567     }
568 
569     for (int pos = std::max(0, message_token_offset - 1);
570          pos < tokens[i].size(); pos++) {
571       ++(*total_token_count);
572       if (!feature_processor_->AppendTokenFeatures(
573               tokens[i][pos], embedding_executor_.get(), embeddings)) {
574         TC3_LOG(ERROR) << "Could not run token feature extractor.";
575         return false;
576       }
577     }
578 
579     // Add `end message` token.
580     ++(*total_token_count);
581     embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
582                        embedded_end_token_.end());
583 
584     // Reset for the subsequent messages.
585     message_token_offset = 0;
586   }
587 
588   // Add optional padding.
589   const int min_num_total_tokens =
590       model_->feature_processor_options()->min_num_total_tokens();
591   for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
592     embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
593                        embedded_padding_token_.end());
594   }
595 
596   return true;
597 }
598 
AllocateInput(const int conversation_length,const int max_tokens,const int total_token_count,tflite::Interpreter * interpreter) const599 bool ActionsSuggestions::AllocateInput(const int conversation_length,
600                                        const int max_tokens,
601                                        const int total_token_count,
602                                        tflite::Interpreter* interpreter) const {
603   if (model_->tflite_model_spec()->resize_inputs()) {
604     if (model_->tflite_model_spec()->input_context() >= 0) {
605       interpreter->ResizeInputTensor(
606           interpreter->inputs()[model_->tflite_model_spec()->input_context()],
607           {1, conversation_length});
608     }
609     if (model_->tflite_model_spec()->input_user_id() >= 0) {
610       interpreter->ResizeInputTensor(
611           interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
612           {1, conversation_length});
613     }
614     if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
615       interpreter->ResizeInputTensor(
616           interpreter
617               ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
618           {1, conversation_length});
619     }
620     if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
621       interpreter->ResizeInputTensor(
622           interpreter
623               ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
624           {conversation_length, 1});
625     }
626     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
627       interpreter->ResizeInputTensor(
628           interpreter
629               ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
630           {conversation_length, max_tokens, token_embedding_size_});
631     }
632     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
633       interpreter->ResizeInputTensor(
634           interpreter->inputs()[model_->tflite_model_spec()
635                                     ->input_flattened_token_embeddings()],
636           {1, total_token_count});
637     }
638   }
639 
640   return interpreter->AllocateTensors() == kTfLiteOk;
641 }
642 
SetupModelInput(const std::vector<std::string> & context,const std::vector<int> & user_ids,const std::vector<float> & time_diffs,const int num_suggestions,const ActionSuggestionOptions & options,tflite::Interpreter * interpreter) const643 bool ActionsSuggestions::SetupModelInput(
644     const std::vector<std::string>& context, const std::vector<int>& user_ids,
645     const std::vector<float>& time_diffs, const int num_suggestions,
646     const ActionSuggestionOptions& options,
647     tflite::Interpreter* interpreter) const {
648   // Compute token embeddings.
649   std::vector<std::vector<Token>> tokens;
650   std::vector<float> token_embeddings;
651   std::vector<float> flattened_token_embeddings;
652   int max_tokens = 0;
653   int total_token_count = 0;
654   if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
655       model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
656       model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
657     if (feature_processor_ == nullptr) {
658       TC3_LOG(ERROR) << "No feature processor specified.";
659       return false;
660     }
661 
662     // Tokenize the messages in the conversation.
663     tokens = Tokenize(context);
664     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
665       if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
666         TC3_LOG(ERROR) << "Could not extract token features.";
667         return false;
668       }
669     }
670     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
671       if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
672                                  &total_token_count)) {
673         TC3_LOG(ERROR) << "Could not extract token features.";
674         return false;
675       }
676     }
677   }
678 
679   if (!AllocateInput(context.size(), max_tokens, total_token_count,
680                      interpreter)) {
681     TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
682     return false;
683   }
684   if (model_->tflite_model_spec()->input_context() >= 0) {
685     if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
686       model_executor_->SetInput<std::string>(
687           model_->tflite_model_spec()->input_context(),
688           PadOrTruncateToTargetLength(
689               context, model_->tflite_model_spec()->input_length_to_pad(),
690               std::string("")),
691           interpreter);
692     } else {
693       model_executor_->SetInput<std::string>(
694           model_->tflite_model_spec()->input_context(), context, interpreter);
695     }
696   }
697   if (model_->tflite_model_spec()->input_context_length() >= 0) {
698     model_executor_->SetInput<int>(
699         model_->tflite_model_spec()->input_context_length(), context.size(),
700         interpreter);
701   }
702   if (model_->tflite_model_spec()->input_user_id() >= 0) {
703     if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
704       model_executor_->SetInput<int>(
705           model_->tflite_model_spec()->input_user_id(),
706           PadOrTruncateToTargetLength(
707               user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
708           interpreter);
709     } else {
710       model_executor_->SetInput<int>(
711           model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
712     }
713   }
714   if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
715     model_executor_->SetInput<int>(
716         model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
717         interpreter);
718   }
719   if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
720     model_executor_->SetInput<float>(
721         model_->tflite_model_spec()->input_time_diffs(), time_diffs,
722         interpreter);
723   }
724   if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
725     std::vector<int> num_tokens_per_message(tokens.size());
726     for (int i = 0; i < tokens.size(); i++) {
727       num_tokens_per_message[i] = tokens[i].size();
728     }
729     model_executor_->SetInput<int>(
730         model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
731         interpreter);
732   }
733   if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
734     model_executor_->SetInput<float>(
735         model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
736         interpreter);
737   }
738   if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
739     model_executor_->SetInput<float>(
740         model_->tflite_model_spec()->input_flattened_token_embeddings(),
741         flattened_token_embeddings, interpreter);
742   }
743   // Set up additional input parameters.
744   if (const auto* input_name_index =
745           model_->tflite_model_spec()->input_name_index()) {
746     const std::unordered_map<std::string, Variant>& model_parameters =
747         options.model_parameters;
748     for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
749          *input_name_index) {
750       const std::string param_name = entry->key()->str();
751       const int param_index = entry->value();
752       const TfLiteType param_type =
753           interpreter->tensor(interpreter->inputs()[param_index])->type;
754       const auto param_value_it = model_parameters.find(param_name);
755       const bool has_value = param_value_it != model_parameters.end();
756       switch (param_type) {
757         case kTfLiteFloat32:
758           if (has_value) {
759             SetVectorOrScalarAsModelInput<float>(param_index,
760                                                  param_value_it->second,
761                                                  interpreter, model_executor_);
762           } else {
763             model_executor_->SetInput<float>(param_index, kDefaultFloat,
764                                              interpreter);
765           }
766           break;
767         case kTfLiteInt32:
768           if (has_value) {
769             SetVectorOrScalarAsModelInput<int32_t>(
770                 param_index, param_value_it->second, interpreter,
771                 model_executor_);
772           } else {
773             model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
774                                                interpreter);
775           }
776           break;
777         case kTfLiteInt64:
778           model_executor_->SetInput<int64_t>(
779               param_index,
780               has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
781               interpreter);
782           break;
783         case kTfLiteUInt8:
784           model_executor_->SetInput<uint8_t>(
785               param_index,
786               has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
787               interpreter);
788           break;
789         case kTfLiteInt8:
790           model_executor_->SetInput<int8_t>(
791               param_index,
792               has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
793               interpreter);
794           break;
795         case kTfLiteBool:
796           model_executor_->SetInput<bool>(
797               param_index,
798               has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
799               interpreter);
800           break;
801         default:
802           TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
803                          << param_name;
804       }
805     }
806   }
807   return true;
808 }
809 
PopulateTextReplies(const tflite::Interpreter * interpreter,int suggestion_index,int score_index,const std::string & type,ActionsSuggestionsResponse * response) const810 void ActionsSuggestions::PopulateTextReplies(
811     const tflite::Interpreter* interpreter, int suggestion_index,
812     int score_index, const std::string& type,
813     ActionsSuggestionsResponse* response) const {
814   const std::vector<tflite::StringRef> replies =
815       model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
816   const TensorView<float> scores =
817       model_executor_->OutputView<float>(score_index, interpreter);
818   for (int i = 0; i < replies.size(); i++) {
819     if (replies[i].len == 0) {
820       continue;
821     }
822     const float score = scores.data()[i];
823     if (score < preconditions_.min_reply_score_threshold) {
824       continue;
825     }
826     response->actions.push_back(
827         {std::string(replies[i].str, replies[i].len), type, score});
828   }
829 }
830 
FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec * spec,ActionSuggestion * suggestion) const831 void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
832     const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
833   std::unique_ptr<MutableFlatbuffer> entity_data =
834       entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
835                                       : nullptr;
836   FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
837 }
838 
PopulateIntentTriggering(const tflite::Interpreter * interpreter,int suggestion_index,int score_index,const ActionSuggestionSpec * task_spec,ActionsSuggestionsResponse * response) const839 void ActionsSuggestions::PopulateIntentTriggering(
840     const tflite::Interpreter* interpreter, int suggestion_index,
841     int score_index, const ActionSuggestionSpec* task_spec,
842     ActionsSuggestionsResponse* response) const {
843   if (!task_spec || task_spec->type()->size() == 0) {
844     TC3_LOG(ERROR)
845         << "Task type for intent (action) triggering cannot be empty!";
846     return;
847   }
848   const TensorView<bool> intent_prediction =
849       model_executor_->OutputView<bool>(suggestion_index, interpreter);
850   const TensorView<float> intent_scores =
851       model_executor_->OutputView<float>(score_index, interpreter);
852   // Two result corresponding to binary triggering case.
853   TC3_CHECK_EQ(intent_prediction.size(), 2);
854   TC3_CHECK_EQ(intent_scores.size(), 2);
855   // We rely on in-graph thresholding logic so at this point the results
856   // have been ranked properly according to threshold.
857   const bool triggering = intent_prediction.data()[0];
858   const float trigger_score = intent_scores.data()[0];
859 
860   if (triggering) {
861     ActionSuggestion suggestion;
862     std::unique_ptr<MutableFlatbuffer> entity_data =
863         entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
864                                         : nullptr;
865     FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
866     suggestion.score = trigger_score;
867     response->actions.push_back(std::move(suggestion));
868   }
869 }
870 
ReadModelOutput(tflite::Interpreter * interpreter,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const871 bool ActionsSuggestions::ReadModelOutput(
872     tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
873     ActionsSuggestionsResponse* response) const {
874   // Read sensitivity and triggering score predictions.
875   if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
876     const TensorView<float> triggering_score =
877         model_executor_->OutputView<float>(
878             model_->tflite_model_spec()->output_triggering_score(),
879             interpreter);
880     if (!triggering_score.is_valid() || triggering_score.size() == 0) {
881       TC3_LOG(ERROR) << "Could not compute triggering score.";
882       return false;
883     }
884     response->triggering_score = triggering_score.data()[0];
885     response->output_filtered_min_triggering_score =
886         (response->triggering_score <
887          preconditions_.min_smart_reply_triggering_score);
888   }
889   if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
890     const TensorView<float> sensitive_topic_score =
891         model_executor_->OutputView<float>(
892             model_->tflite_model_spec()->output_sensitive_topic_score(),
893             interpreter);
894     if (!sensitive_topic_score.is_valid() ||
895         sensitive_topic_score.dim(0) != 1) {
896       TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
897       return false;
898     }
899     response->sensitivity_score = sensitive_topic_score.data()[0];
900     response->is_sensitive = (response->sensitivity_score >
901                               preconditions_.max_sensitive_topic_score);
902   }
903 
904   // Suppress model outputs.
905   if (response->is_sensitive) {
906     return true;
907   }
908 
909   // Read smart reply predictions.
910   if (!response->output_filtered_min_triggering_score &&
911       model_->tflite_model_spec()->output_replies() >= 0) {
912     PopulateTextReplies(interpreter,
913                         model_->tflite_model_spec()->output_replies(),
914                         model_->tflite_model_spec()->output_replies_scores(),
915                         model_->smart_reply_action_type()->str(), response);
916   }
917 
918   // Read actions suggestions.
919   if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
920     const TensorView<float> actions_scores = model_executor_->OutputView<float>(
921         model_->tflite_model_spec()->output_actions_scores(), interpreter);
922     for (int i = 0; i < model_->action_type()->size(); i++) {
923       const ActionTypeOptions* action_type = model_->action_type()->Get(i);
924       // Skip disabled action classes, such as the default other category.
925       if (!action_type->enabled()) {
926         continue;
927       }
928       const float score = actions_scores.data()[i];
929       if (score < action_type->min_triggering_score()) {
930         continue;
931       }
932 
933       // Create action from model output.
934       ActionSuggestion suggestion;
935       suggestion.type = action_type->name()->str();
936       std::unique_ptr<MutableFlatbuffer> entity_data =
937           entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
938                                           : nullptr;
939       FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
940       suggestion.score = score;
941       response->actions.push_back(std::move(suggestion));
942     }
943   }
944 
945   // Read multi-task predictions and construct the result properly.
946   if (const auto* prediction_metadata =
947           model_->tflite_model_spec()->prediction_metadata()) {
948     for (const PredictionMetadata* metadata : *prediction_metadata) {
949       const ActionSuggestionSpec* task_spec = metadata->task_spec();
950       const int suggestions_index = metadata->output_suggestions();
951       const int suggestions_scores_index =
952           metadata->output_suggestions_scores();
953       switch (metadata->prediction_type()) {
954         case PredictionType_NEXT_MESSAGE_PREDICTION:
955           if (!task_spec || task_spec->type()->size() == 0) {
956             TC3_LOG(WARNING) << "Task type not provided, use default "
957                                 "smart_reply_action_type!";
958           }
959           PopulateTextReplies(
960               interpreter, suggestions_index, suggestions_scores_index,
961               task_spec ? task_spec->type()->str()
962                         : model_->smart_reply_action_type()->str(),
963               response);
964           break;
965         case PredictionType_INTENT_TRIGGERING:
966           PopulateIntentTriggering(interpreter, suggestions_index,
967                                    suggestions_scores_index, task_spec,
968                                    response);
969           break;
970         default:
971           TC3_LOG(ERROR) << "Unsupported prediction type!";
972           return false;
973       }
974     }
975   }
976 
977   return true;
978 }
979 
SuggestActionsFromModel(const Conversation & conversation,const int num_messages,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response,std::unique_ptr<tflite::Interpreter> * interpreter) const980 bool ActionsSuggestions::SuggestActionsFromModel(
981     const Conversation& conversation, const int num_messages,
982     const ActionSuggestionOptions& options,
983     ActionsSuggestionsResponse* response,
984     std::unique_ptr<tflite::Interpreter>* interpreter) const {
985   TC3_CHECK_LE(num_messages, conversation.messages.size());
986 
987   if (sensitive_model_ != nullptr &&
988       sensitive_model_->EvalConversation(conversation, num_messages).first) {
989     response->is_sensitive = true;
990     return true;
991   }
992 
993   if (!model_executor_) {
994     return true;
995   }
996   *interpreter = model_executor_->CreateInterpreter();
997 
998   if (!*interpreter) {
999     TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
1000                       "actions suggestions model.";
1001     return false;
1002   }
1003 
1004   std::vector<std::string> context;
1005   std::vector<int> user_ids;
1006   std::vector<float> time_diffs;
1007   context.reserve(num_messages);
1008   user_ids.reserve(num_messages);
1009   time_diffs.reserve(num_messages);
1010 
1011   // Gather last `num_messages` messages from the conversation.
1012   int64 last_message_reference_time_ms_utc = 0;
1013   const float second_in_ms = 1000;
1014   for (int i = conversation.messages.size() - num_messages;
1015        i < conversation.messages.size(); i++) {
1016     const ConversationMessage& message = conversation.messages[i];
1017     context.push_back(message.text);
1018     user_ids.push_back(message.user_id);
1019 
1020     float time_diff_secs = 0;
1021     if (message.reference_time_ms_utc != 0 &&
1022         last_message_reference_time_ms_utc != 0) {
1023       time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
1024                                        last_message_reference_time_ms_utc) /
1025                                           second_in_ms);
1026     }
1027     if (message.reference_time_ms_utc != 0) {
1028       last_message_reference_time_ms_utc = message.reference_time_ms_utc;
1029     }
1030     time_diffs.push_back(time_diff_secs);
1031   }
1032 
1033   if (!SetupModelInput(context, user_ids, time_diffs,
1034                        /*num_suggestions=*/model_->num_smart_replies(), options,
1035                        interpreter->get())) {
1036     TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1037     return false;
1038   }
1039 
1040   if ((*interpreter)->Invoke() != kTfLiteOk) {
1041     TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1042     return false;
1043   }
1044 
1045   return ReadModelOutput(interpreter->get(), options, response);
1046 }
1047 
SuggestActionsFromConversationIntentDetection(const Conversation & conversation,const ActionSuggestionOptions & options,std::vector<ActionSuggestion> * actions) const1048 Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
1049     const Conversation& conversation, const ActionSuggestionOptions& options,
1050     std::vector<ActionSuggestion>* actions) const {
1051   TC3_ASSIGN_OR_RETURN(
1052       std::vector<ActionSuggestion> new_actions,
1053       conversation_intent_detection_->SuggestActions(conversation, options));
1054   for (auto& action : new_actions) {
1055     actions->push_back(std::move(action));
1056   }
1057   return Status::OK;
1058 }
1059 
AnnotationOptionsForMessage(const ConversationMessage & message) const1060 AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1061     const ConversationMessage& message) const {
1062   AnnotationOptions options;
1063   options.detected_text_language_tags = message.detected_text_language_tags;
1064   options.reference_time_ms_utc = message.reference_time_ms_utc;
1065   options.reference_timezone = message.reference_timezone;
1066   options.annotation_usecase =
1067       model_->annotation_actions_spec()->annotation_usecase();
1068   options.is_serialized_entity_data_enabled =
1069       model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1070   options.entity_types = annotation_entity_types_;
1071   return options;
1072 }
1073 
1074 // Run annotator on the messages of a conversation.
AnnotateConversation(const Conversation & conversation,const Annotator * annotator) const1075 Conversation ActionsSuggestions::AnnotateConversation(
1076     const Conversation& conversation, const Annotator* annotator) const {
1077   if (annotator == nullptr) {
1078     return conversation;
1079   }
1080   const int num_messages_grammar =
1081       ((model_->rules() && model_->rules()->grammar_rules() &&
1082         model_->rules()
1083             ->grammar_rules()
1084             ->rules()
1085             ->nonterminals()
1086             ->annotation_nt())
1087            ? 1
1088            : 0);
1089   const int num_messages_mapping =
1090       (model_->annotation_actions_spec()
1091            ? std::max(model_->annotation_actions_spec()
1092                           ->max_history_from_any_person(),
1093                       model_->annotation_actions_spec()
1094                           ->max_history_from_last_person())
1095            : 0);
1096   const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
1097   if (num_messages == 0) {
1098     // No annotations are used.
1099     return conversation;
1100   }
1101   Conversation annotated_conversation = conversation;
1102   for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
1103        i < num_messages && message_index >= 0; i++, message_index--) {
1104     ConversationMessage* message =
1105         &annotated_conversation.messages[message_index];
1106     if (message->annotations.empty()) {
1107       message->annotations = annotator->Annotate(
1108           message->text, AnnotationOptionsForMessage(*message));
1109       ConvertDatetimeToTime(&message->annotations);
1110     }
1111   }
1112   return annotated_conversation;
1113 }
1114 
SuggestActionsFromAnnotations(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1115 void ActionsSuggestions::SuggestActionsFromAnnotations(
1116     const Conversation& conversation,
1117     std::vector<ActionSuggestion>* actions) const {
1118   if (model_->annotation_actions_spec() == nullptr ||
1119       model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1120       model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1121     return;
1122   }
1123 
1124   // Create actions based on the annotations.
1125   const int max_from_any_person =
1126       model_->annotation_actions_spec()->max_history_from_any_person();
1127   const int max_from_last_person =
1128       model_->annotation_actions_spec()->max_history_from_last_person();
1129   const int last_person = conversation.messages.back().user_id;
1130 
1131   int num_messages_last_person = 0;
1132   int num_messages_any_person = 0;
1133   bool all_from_last_person = true;
1134   for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1135        message_index--) {
1136     const ConversationMessage& message = conversation.messages[message_index];
1137     std::vector<AnnotatedSpan> annotations = message.annotations;
1138 
1139     // Update how many messages we have processed from the last person in the
1140     // conversation and from any person in the conversation.
1141     num_messages_any_person++;
1142     if (all_from_last_person && message.user_id == last_person) {
1143       num_messages_last_person++;
1144     } else {
1145       all_from_last_person = false;
1146     }
1147 
1148     if (num_messages_any_person > max_from_any_person &&
1149         (!all_from_last_person ||
1150          num_messages_last_person > max_from_last_person)) {
1151       break;
1152     }
1153 
1154     if (message.user_id == kLocalUserId) {
1155       if (model_->annotation_actions_spec()->only_until_last_sent()) {
1156         break;
1157       }
1158       if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1159         continue;
1160       }
1161     }
1162 
1163     std::vector<ActionSuggestionAnnotation> action_annotations;
1164     action_annotations.reserve(annotations.size());
1165     for (const AnnotatedSpan& annotation : annotations) {
1166       if (annotation.classification.empty()) {
1167         continue;
1168       }
1169 
1170       const ClassificationResult& classification_result =
1171           annotation.classification[0];
1172 
1173       ActionSuggestionAnnotation action_annotation;
1174       action_annotation.span = {
1175           message_index, annotation.span,
1176           UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1177               .UTF8Substring(annotation.span.first, annotation.span.second)};
1178       action_annotation.entity = classification_result;
1179       action_annotation.name = classification_result.collection;
1180       action_annotations.push_back(std::move(action_annotation));
1181     }
1182 
1183     if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1184       // Create actions only for deduplicated annotations.
1185       for (const int annotation_id :
1186            DeduplicateAnnotations(action_annotations)) {
1187         SuggestActionsFromAnnotation(
1188             message_index, action_annotations[annotation_id], actions);
1189       }
1190     } else {
1191       // Create actions for all annotations.
1192       for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1193         SuggestActionsFromAnnotation(message_index, annotation, actions);
1194       }
1195     }
1196   }
1197 }
1198 
SuggestActionsFromAnnotation(const int message_index,const ActionSuggestionAnnotation & annotation,std::vector<ActionSuggestion> * actions) const1199 void ActionsSuggestions::SuggestActionsFromAnnotation(
1200     const int message_index, const ActionSuggestionAnnotation& annotation,
1201     std::vector<ActionSuggestion>* actions) const {
1202   for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
1203        *model_->annotation_actions_spec()->annotation_mapping()) {
1204     if (annotation.entity.collection ==
1205         mapping->annotation_collection()->str()) {
1206       if (annotation.entity.score < mapping->min_annotation_score()) {
1207         continue;
1208       }
1209 
1210       std::unique_ptr<MutableFlatbuffer> entity_data =
1211           entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1212                                           : nullptr;
1213 
1214       // Set annotation text as (additional) entity data field.
1215       if (mapping->entity_field() != nullptr) {
1216         TC3_CHECK_NE(entity_data, nullptr);
1217 
1218         UnicodeText normalized_annotation_text =
1219             UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
1220 
1221         // Apply normalization if specified.
1222         if (mapping->normalization_options() != nullptr) {
1223           normalized_annotation_text =
1224               NormalizeText(*unilib_, mapping->normalization_options(),
1225                             normalized_annotation_text);
1226         }
1227 
1228         entity_data->ParseAndSet(mapping->entity_field(),
1229                                  normalized_annotation_text.ToUTF8String());
1230       }
1231 
1232       ActionSuggestion suggestion;
1233       FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
1234       if (mapping->use_annotation_score()) {
1235         suggestion.score = annotation.entity.score;
1236       }
1237       suggestion.annotations = {annotation};
1238       actions->push_back(std::move(suggestion));
1239     }
1240   }
1241 }
1242 
DeduplicateAnnotations(const std::vector<ActionSuggestionAnnotation> & annotations) const1243 std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1244     const std::vector<ActionSuggestionAnnotation>& annotations) const {
1245   std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1246 
1247   for (int i = 0; i < annotations.size(); i++) {
1248     const std::pair<std::string, std::string> key = {annotations[i].name,
1249                                                      annotations[i].span.text};
1250     auto entry = deduplicated_annotations.find(key);
1251     if (entry != deduplicated_annotations.end()) {
1252       // Kepp the annotation with the higher score.
1253       if (annotations[entry->second].entity.score <
1254           annotations[i].entity.score) {
1255         entry->second = i;
1256       }
1257       continue;
1258     }
1259     deduplicated_annotations.insert(entry, {key, i});
1260   }
1261 
1262   std::vector<int> result;
1263   result.reserve(deduplicated_annotations.size());
1264   for (const auto& key_and_annotation : deduplicated_annotations) {
1265     result.push_back(key_and_annotation.second);
1266   }
1267   return result;
1268 }
1269 
1270 #if !defined(TC3_DISABLE_LUA)
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1271 bool ActionsSuggestions::SuggestActionsFromLua(
1272     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1273     const tflite::Interpreter* interpreter,
1274     const reflection::Schema* annotation_entity_data_schema,
1275     std::vector<ActionSuggestion>* actions) const {
1276   if (lua_bytecode_.empty()) {
1277     return true;
1278   }
1279 
1280   auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1281       lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1282       interpreter, entity_data_schema_, annotation_entity_data_schema);
1283   if (lua_actions == nullptr) {
1284     TC3_LOG(ERROR) << "Could not create lua actions.";
1285     return false;
1286   }
1287   return lua_actions->SuggestActions(actions);
1288 }
1289 #else
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1290 bool ActionsSuggestions::SuggestActionsFromLua(
1291     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1292     const tflite::Interpreter* interpreter,
1293     const reflection::Schema* annotation_entity_data_schema,
1294     std::vector<ActionSuggestion>* actions) const {
1295   return true;
1296 }
1297 #endif
1298 
GatherActionsSuggestions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const1299 bool ActionsSuggestions::GatherActionsSuggestions(
1300     const Conversation& conversation, const Annotator* annotator,
1301     const ActionSuggestionOptions& options,
1302     ActionsSuggestionsResponse* response) const {
1303   if (conversation.messages.empty()) {
1304     return true;
1305   }
1306 
1307   // Run annotator against messages.
1308   const Conversation annotated_conversation =
1309       AnnotateConversation(conversation, annotator);
1310 
1311   const int num_messages = NumMessagesToConsider(
1312       annotated_conversation, model_->max_conversation_history_length());
1313 
1314   if (num_messages <= 0) {
1315     TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1316     return false;
1317   }
1318 
1319   SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
1320 
1321   if (grammar_actions_ != nullptr &&
1322       !grammar_actions_->SuggestActions(annotated_conversation,
1323                                         &response->actions)) {
1324     TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
1325     return false;
1326   }
1327 
1328   int input_text_length = 0;
1329   int num_matching_locales = 0;
1330   for (int i = annotated_conversation.messages.size() - num_messages;
1331        i < annotated_conversation.messages.size(); i++) {
1332     input_text_length += annotated_conversation.messages[i].text.length();
1333     std::vector<Locale> message_languages;
1334     if (!ParseLocales(
1335             annotated_conversation.messages[i].detected_text_language_tags,
1336             &message_languages)) {
1337       continue;
1338     }
1339     if (Locale::IsAnyLocaleSupported(
1340             message_languages, locales_,
1341             preconditions_.handle_unknown_locale_as_supported)) {
1342       ++num_matching_locales;
1343     }
1344   }
1345 
1346   // Bail out if we are provided with too few or too much input.
1347   if (input_text_length < preconditions_.min_input_length ||
1348       (preconditions_.max_input_length >= 0 &&
1349        input_text_length > preconditions_.max_input_length)) {
1350     TC3_LOG(INFO) << "Too much or not enough input for inference.";
1351     return response;
1352   }
1353 
1354   // Bail out if the text does not look like it can be handled by the model.
1355   const float matching_fraction =
1356       static_cast<float>(num_matching_locales) / num_messages;
1357   if (matching_fraction < preconditions_.min_locale_match_fraction) {
1358     TC3_LOG(INFO) << "Not enough locale matches.";
1359     response->output_filtered_locale_mismatch = true;
1360     return true;
1361   }
1362 
1363   std::vector<const UniLib::RegexPattern*> post_check_rules;
1364   if (preconditions_.suppress_on_low_confidence_input) {
1365     if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
1366                                              num_messages, &post_check_rules)) {
1367       response->output_filtered_low_confidence = true;
1368       return true;
1369     }
1370   }
1371 
1372   std::unique_ptr<tflite::Interpreter> interpreter;
1373   if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
1374                                response, &interpreter)) {
1375     TC3_LOG(ERROR) << "Could not run model.";
1376     return false;
1377   }
1378 
1379   // SuggestActionsFromModel also detects if the conversation is sensitive,
1380   // either by using the old ngram model or the new model.
1381   // Suppress all predictions if the conversation was deemed sensitive.
1382   if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
1383     return true;
1384   }
1385 
1386   if (conversation_intent_detection_) {
1387     // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
1388     auto actions = SuggestActionsFromConversationIntentDetection(
1389         annotated_conversation, options, &response->actions);
1390     if (!actions.ok()) {
1391       TC3_LOG(ERROR) << "Could not run conversation intent detection: "
1392                      << actions.error_message();
1393       return false;
1394     }
1395   }
1396 
1397   if (!SuggestActionsFromLua(
1398           annotated_conversation, model_executor_.get(), interpreter.get(),
1399           annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1400           &response->actions)) {
1401     TC3_LOG(ERROR) << "Could not suggest actions from script.";
1402     return false;
1403   }
1404 
1405   if (!regex_actions_->SuggestActions(annotated_conversation,
1406                                       entity_data_builder_.get(),
1407                                       &response->actions)) {
1408     TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
1409     return false;
1410   }
1411 
1412   if (preconditions_.suppress_on_low_confidence_input &&
1413       !regex_actions_->FilterConfidenceOutput(post_check_rules,
1414                                               &response->actions)) {
1415     TC3_LOG(ERROR) << "Could not post-check actions.";
1416     return false;
1417   }
1418 
1419   return true;
1420 }
1421 
SuggestActions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options) const1422 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1423     const Conversation& conversation, const Annotator* annotator,
1424     const ActionSuggestionOptions& options) const {
1425   ActionsSuggestionsResponse response;
1426 
1427   // Assert that messages are sorted correctly.
1428   for (int i = 1; i < conversation.messages.size(); i++) {
1429     if (conversation.messages[i].reference_time_ms_utc <
1430         conversation.messages[i - 1].reference_time_ms_utc) {
1431       TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
1432       return response;
1433     }
1434   }
1435 
1436   // Check that messages are valid utf8.
1437   for (const ConversationMessage& message : conversation.messages) {
1438     if (message.text.size() > std::numeric_limits<int>::max()) {
1439       TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
1440       return {};
1441     }
1442 
1443     if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
1444             message.text.data(), message.text.size(), /*do_copy=*/false))) {
1445       TC3_LOG(ERROR) << "Not valid utf8 provided.";
1446       return response;
1447     }
1448   }
1449 
1450   if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1451     TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1452     response.actions.clear();
1453   } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1454                                    annotator != nullptr
1455                                        ? annotator->entity_data_schema()
1456                                        : nullptr)) {
1457     TC3_LOG(ERROR) << "Could not rank actions.";
1458     response.actions.clear();
1459   }
1460   return response;
1461 }
1462 
SuggestActions(const Conversation & conversation,const ActionSuggestionOptions & options) const1463 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1464     const Conversation& conversation,
1465     const ActionSuggestionOptions& options) const {
1466   return SuggestActions(conversation, /*annotator=*/nullptr, options);
1467 }
1468 
model() const1469 const ActionsModel* ActionsSuggestions::model() const { return model_; }
entity_data_schema() const1470 const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
1471   return entity_data_schema_;
1472 }
1473 
ViewActionsModel(const void * buffer,int size)1474 const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1475   if (buffer == nullptr) {
1476     return nullptr;
1477   }
1478   return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1479 }
1480 
InitializeConversationIntentDetection(const std::string & serialized_config)1481 bool ActionsSuggestions::InitializeConversationIntentDetection(
1482     const std::string& serialized_config) {
1483   auto conversation_intent_detection =
1484       std::make_unique<ConversationIntentDetection>();
1485   if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
1486     TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
1487     return false;
1488   }
1489   conversation_intent_detection_ = std::move(conversation_intent_detection);
1490   return true;
1491 }
1492 
1493 }  // namespace libtextclassifier3
1494