1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/actions-suggestions.h"
18 
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 
24 #include "actions/actions_model_generated.h"
25 #include "actions/test-utils.h"
26 #include "actions/zlib-utils.h"
27 #include "annotator/collections.h"
28 #include "annotator/types.h"
29 #include "utils/flatbuffers/flatbuffers.h"
30 #include "utils/flatbuffers/flatbuffers_generated.h"
31 #include "utils/flatbuffers/mutable.h"
32 #include "utils/grammar/utils/locale-shard-map.h"
33 #include "utils/grammar/utils/rules.h"
34 #include "utils/hash/farmhash.h"
35 #include "utils/jvm-test-utils.h"
36 #include "utils/test-data-test-utils.h"
37 #include "gmock/gmock.h"
38 #include "gtest/gtest.h"
39 #include "flatbuffers/flatbuffers.h"
40 #include "flatbuffers/reflection.h"
41 
42 namespace libtextclassifier3 {
43 namespace {
44 
45 using ::testing::ElementsAre;
46 using ::testing::FloatEq;
47 using ::testing::IsEmpty;
48 using ::testing::NotNull;
49 using ::testing::SizeIs;
50 
// Test model file names. They are meant to be appended to GetModelPath() to
// form the full path of the model on disk.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";

// Multi-task model variants (named "multi_task_*" on disk).
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
66 
// Reads the entire contents of `file_name` and returns them as a string.
// The file is opened in binary mode: callers in this file read serialized
// model files, and text mode may translate line endings on some platforms,
// corrupting the bytes. Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
71 
GetModelPath()72 std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
73 
74 class ActionsSuggestionsTest : public testing::Test {
75  protected:
ActionsSuggestionsTest()76   explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
LoadTestModel(const std::string model_file_name)77   std::unique_ptr<ActionsSuggestions> LoadTestModel(
78       const std::string model_file_name) {
79     return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
80                                         unilib_.get());
81   }
LoadHashGramTestModel()82   std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
83     return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
84                                         unilib_.get());
85   }
LoadMultiTaskTestModel()86   std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
87     return ActionsSuggestions::FromPath(
88         GetModelPath() + kMultiTaskModelFileName, unilib_.get());
89   }
90 
LoadMultiTaskSrP13nTestModel()91   std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
92     return ActionsSuggestions::FromPath(
93         GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
94   }
95   std::unique_ptr<UniLib> unilib_;
96 };
97 
// The baseline test model can be loaded.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> model =
      LoadTestModel(kModelFileName);
  EXPECT_THAT(model, NotNull());
}
101 
// A message ending in a truncated UTF-8 sequence ("\xf0\x9f" starts a 4-byte
// sequence that never completes) yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
113 
// A message containing encoded surrogate code points ("\xed\xa0\x80" is
// U+D800, which is not valid UTF-8) yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
127 
// A simple question to the baseline model yields exactly three actions.
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // share_location + 2 smart replies. SizeIs avoids a signed/unsigned
  // comparison between size() and an int literal.
  EXPECT_THAT(response.actions, SizeIs(3));
}
138 
// A message tagged with the locale "zz" yields no actions.
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
149 
// An "address" annotation over "home" makes "view_map" the top action, with
// score 1.0.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan annotation;
  annotation.span = {11, 15};  // "home"
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_FALSE(response.actions.empty());
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
167 
// An annotation mapping with an `entity_field` copies the annotated text into
// the suggested action's serialized entity data.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config: "address" annotations map to
  // a "save_location" action whose `location` entity field gets the text.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};  // "home"
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_FALSE(response.actions.empty());
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Each step is guarded so a missing payload fails the test
  // rather than reading invalid memory.
  const auto& entity_data = response.actions.front().serialized_entity_data;
  ASSERT_FALSE(entity_data.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "home");
}
220 
TEST_F(ActionsSuggestionsTest,SuggestsActionsFromAnnotationsWithNormalization)221 TEST_F(ActionsSuggestionsTest,
222        SuggestsActionsFromAnnotationsWithNormalization) {
223   const std::string actions_model_string =
224       ReadFile(GetModelPath() + kModelFileName);
225   std::unique_ptr<ActionsModelT> actions_model =
226       UnPackActionsModel(actions_model_string.c_str());
227   SetTestEntityDataSchema(actions_model.get());
228 
229   // Set custom actions from annotations config.
230   actions_model->annotation_actions_spec->annotation_mapping.clear();
231   actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
232       new AnnotationActionsSpec_::AnnotationMappingT);
233   AnnotationActionsSpec_::AnnotationMappingT* mapping =
234       actions_model->annotation_actions_spec->annotation_mapping.back().get();
235   mapping->annotation_collection = "address";
236   mapping->action.reset(new ActionSuggestionSpecT);
237   mapping->action->type = "save_location";
238   mapping->action->score = 1.0;
239   mapping->action->priority_score = 2.0;
240   mapping->entity_field.reset(new FlatbufferFieldPathT);
241   mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
242   mapping->entity_field->field.back()->field_name = "location";
243   mapping->normalization_options.reset(new NormalizationOptionsT);
244   mapping->normalization_options->codepointwise_normalization =
245       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
246 
247   flatbuffers::FlatBufferBuilder builder;
248   FinishActionsModelBuffer(builder,
249                            ActionsModel::Pack(builder, actions_model.get()));
250   std::unique_ptr<ActionsSuggestions> actions_suggestions =
251       ActionsSuggestions::FromUnownedBuffer(
252           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
253           builder.GetSize(), unilib_.get());
254 
255   AnnotatedSpan annotation;
256   annotation.span = {11, 15};
257   annotation.classification = {ClassificationResult("address", 1.0)};
258   const ActionsSuggestionsResponse response =
259       actions_suggestions->SuggestActions(
260           {{{/*user_id=*/1, "are you at home?",
261              /*reference_time_ms_utc=*/0,
262              /*reference_timezone=*/"Europe/Zurich",
263              /*annotations=*/{annotation},
264              /*locales=*/"en"}}});
265   ASSERT_GE(response.actions.size(), 1);
266   EXPECT_EQ(response.actions.front().type, "save_location");
267   EXPECT_EQ(response.actions.front().score, 1.0);
268 
269   // Check that the `location` entity field holds the normalized text of the
270   // annotation.
271   const flatbuffers::Table* entity =
272       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
273           response.actions.front().serialized_entity_data.data()));
274   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
275             "HOME");
276 }
277 
// Two flight annotations on the same text ("LX38") collapse to one
// "track_flight" action carrying the higher score, followed by the email
// action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
307 
// With `deduplicate_annotations` disabled, both flight annotations produce
// their own "track_flight" action, ordered by score, followed by the email
// action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of crashing on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
350 
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)351 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
352     const std::function<void(ActionsModelT*)>& set_config_fn,
353     const UniLib* unilib = nullptr) {
354   const std::string actions_model_string =
355       ReadFile(GetModelPath() + kModelFileName);
356   std::unique_ptr<ActionsModelT> actions_model =
357       UnPackActionsModel(actions_model_string.c_str());
358 
359   // Set custom config.
360   set_config_fn(actions_model.get());
361 
362   // Disable smart reply for easier testing.
363   actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
364 
365   flatbuffers::FlatBufferBuilder builder;
366   FinishActionsModelBuffer(builder,
367                            ActionsModel::Pack(builder, actions_model.get()));
368   std::unique_ptr<ActionsSuggestions> actions_suggestions =
369       ActionsSuggestions::FromUnownedBuffer(
370           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
371           builder.GetSize(), unilib);
372 
373   AnnotatedSpan flight_annotation;
374   flight_annotation.span = {15, 19};
375   flight_annotation.classification = {ClassificationResult("flight", 2.0)};
376   AnnotatedSpan email_annotation;
377   email_annotation.span = {0, 16};
378   email_annotation.classification = {ClassificationResult("email", 1.0)};
379 
380   return actions_suggestions->SuggestActions(
381       {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
382          "hehe@android.com",
383          /*reference_time_ms_utc=*/0,
384          /*reference_timezone=*/"Europe/Zurich",
385          /*annotations=*/
386          {email_annotation},
387          /*locales=*/"en"},
388         {/*user_id=*/2,
389          "yoyo@android.com",
390          /*reference_time_ms_utc=*/0,
391          /*reference_timezone=*/"Europe/Zurich",
392          /*annotations=*/
393          {email_annotation},
394          /*locales=*/"en"},
395         {/*user_id=*/1,
396          "test@android.com",
397          /*reference_time_ms_utc=*/0,
398          /*reference_timezone=*/"Europe/Zurich",
399          /*annotations=*/
400          {email_annotation},
401          /*locales=*/"en"},
402         {/*user_id=*/1,
403          "I am on flight LX38.",
404          /*reference_time_ms_utc=*/0,
405          /*reference_timezone=*/"Europe/Zurich",
406          /*annotations=*/
407          {flight_annotation},
408          /*locales=*/"en"}}});
409 }
410 
// With history limited to one message (from anyone and from the last person),
// only the flight action from the final message is expected.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
425 
// With up to three messages from the last person but only one from anyone
// else, expects the flight action followed by one email action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
441 
// With two messages from anyone, expects the flight action followed by one
// email action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
457 
// With three messages from anyone, expects the flight action followed by two
// email actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
475 
// Even with five messages allowed from anyone, local-user messages are
// excluded. Expects the flight action followed by two email actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
493 
// With local-user messages included, expects the flight action followed by
// one email action per email message (three).
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so a size mismatch cannot cause out-of-bounds
  // indexing below.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
512 
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")513 void TestSuggestActionsWithThreshold(
514     const std::function<void(ActionsModelT*)>& set_value_fn,
515     const UniLib* unilib = nullptr, const int expected_size = 0,
516     const std::string& preconditions_overwrite = "") {
517   const std::string actions_model_string =
518       ReadFile(GetModelPath() + kModelFileName);
519   std::unique_ptr<ActionsModelT> actions_model =
520       UnPackActionsModel(actions_model_string.c_str());
521   set_value_fn(actions_model.get());
522   flatbuffers::FlatBufferBuilder builder;
523   FinishActionsModelBuffer(builder,
524                            ActionsModel::Pack(builder, actions_model.get()));
525   std::unique_ptr<ActionsSuggestions> actions_suggestions =
526       ActionsSuggestions::FromUnownedBuffer(
527           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
528           builder.GetSize(), unilib, preconditions_overwrite);
529   ASSERT_TRUE(actions_suggestions);
530   const ActionsSuggestionsResponse response =
531       actions_suggestions->SuggestActions(
532           {{{/*user_id=*/1, "I have the low-ground. Where are you?",
533              /*reference_time_ms_utc=*/0,
534              /*reference_timezone=*/"Europe/Zurich",
535              /*annotations=*/{}, /*locales=*/"en"}}});
536   EXPECT_LE(response.actions.size(), expected_size);
537 }
538 
// Raising the smart-reply triggering score to 1.0 leaves no smart replies,
// only (at most one) non-reply action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, unilib_.get(),
                                  /*expected_size=*/1);
}
548 
// Raising the minimum reply score to 1.0 leaves no smart replies, only (at
// most one) non-reply action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
  const auto raise_min_reply_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_min_reply_score, unilib_.get(),
                                  /*expected_size=*/1);
}
558 
// With max_sensitive_topic_score at 0.0, at most four actions are expected;
// the test model makes no sensitive prediction.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
  const auto zero_sensitive_topic_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitive_topic_score, unilib_.get(),
                                  /*expected_size=*/4);
}
567 
// A maximum input length of 0 means no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
  const auto zero_max_input_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_input_length, unilib_.get(),
                                  /*expected_size=*/0);
}
575 
// A minimum input length of 100 exceeds the helper's input, so no actions are
// expected.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
  const auto large_min_input_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(large_min_input_length, unilib_.get(),
                                  /*expected_size=*/0);
}
583 
// A serialized preconditions overwrite with max_input_length = 0 suppresses
// all actions, even though the model itself is left unchanged.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT preconditions_overwrite;
  preconditions_overwrite.max_input_length = 0;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(
      TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
  TestSuggestActionsWithThreshold(
      // Keep model untouched; the parameter is unnamed to avoid an
      // unused-parameter warning.
      [](ActionsModelT* /*actions_model*/) {}, unilib_.get(),
      /*expected_size=*/0,
      std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                  builder.GetSize()));
}
597 
598 #ifdef TC3_UNILIB_ICU
// Enables suppression on low-confidence input and adds a low-confidence regex
// matching "low-ground", which appears in the helper's input message. No
// actions are expected (default bound of 0).
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
  const auto add_low_confidence_rule = [](ActionsModelT* actions_model) {
    actions_model->preconditions->suppress_on_low_confidence_input = true;
    actions_model->low_confidence_rules.reset(new RulesModelT);
    RulesModelT* low_confidence_rules =
        actions_model->low_confidence_rules.get();
    low_confidence_rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
    low_confidence_rules->regex_rule.back()->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(add_low_confidence_rule, unilib_.get());
}
611 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutput)612 TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
613   const std::string actions_model_string =
614       ReadFile(GetModelPath() + kModelFileName);
615   std::unique_ptr<ActionsModelT> actions_model =
616       UnPackActionsModel(actions_model_string.c_str());
617   // Add custom triggering rule.
618   actions_model->rules.reset(new RulesModelT());
619   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
620   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
621   rule->pattern = "^(?i:hello\\s(there))$";
622   {
623     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
624         new RulesModel_::RuleActionSpecT);
625     rule_action->action.reset(new ActionSuggestionSpecT);
626     rule_action->action->type = "text_reply";
627     rule_action->action->response_text = "General Desaster!";
628     rule_action->action->score = 1.0f;
629     rule_action->action->priority_score = 1.0f;
630     rule->actions.push_back(std::move(rule_action));
631   }
632   {
633     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
634         new RulesModel_::RuleActionSpecT);
635     rule_action->action.reset(new ActionSuggestionSpecT);
636     rule_action->action->type = "text_reply";
637     rule_action->action->response_text = "General Kenobi!";
638     rule_action->action->score = 1.0f;
639     rule_action->action->priority_score = 1.0f;
640     rule->actions.push_back(std::move(rule_action));
641   }
642 
643   // Add input-output low confidence rule.
644   actions_model->preconditions->suppress_on_low_confidence_input = true;
645   actions_model->low_confidence_rules.reset(new RulesModelT);
646   actions_model->low_confidence_rules->regex_rule.emplace_back(
647       new RulesModel_::RegexRuleT);
648   actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
649   actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
650       "(?i:desaster)";
651 
652   flatbuffers::FlatBufferBuilder builder;
653   FinishActionsModelBuffer(builder,
654                            ActionsModel::Pack(builder, actions_model.get()));
655   std::unique_ptr<ActionsSuggestions> actions_suggestions =
656       ActionsSuggestions::FromUnownedBuffer(
657           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
658           builder.GetSize(), unilib_.get());
659   ASSERT_TRUE(actions_suggestions);
660   const ActionsSuggestionsResponse response =
661       actions_suggestions->SuggestActions(
662           {{{/*user_id=*/1, "hello there",
663              /*reference_time_ms_utc=*/0,
664              /*reference_timezone=*/"Europe/Zurich",
665              /*annotations=*/{}, /*locales=*/"en"}}});
666   ASSERT_GE(response.actions.size(), 1);
667   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
668 }
669 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutputOverwrite)670 TEST_F(ActionsSuggestionsTest,
671        SuggestsActionsLowConfidenceInputOutputOverwrite) {
672   const std::string actions_model_string =
673       ReadFile(GetModelPath() + kModelFileName);
674   std::unique_ptr<ActionsModelT> actions_model =
675       UnPackActionsModel(actions_model_string.c_str());
676   actions_model->low_confidence_rules.reset();
677 
678   // Add custom triggering rule.
679   actions_model->rules.reset(new RulesModelT());
680   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
681   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
682   rule->pattern = "^(?i:hello\\s(there))$";
683   {
684     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
685         new RulesModel_::RuleActionSpecT);
686     rule_action->action.reset(new ActionSuggestionSpecT);
687     rule_action->action->type = "text_reply";
688     rule_action->action->response_text = "General Desaster!";
689     rule_action->action->score = 1.0f;
690     rule_action->action->priority_score = 1.0f;
691     rule->actions.push_back(std::move(rule_action));
692   }
693   {
694     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
695         new RulesModel_::RuleActionSpecT);
696     rule_action->action.reset(new ActionSuggestionSpecT);
697     rule_action->action->type = "text_reply";
698     rule_action->action->response_text = "General Kenobi!";
699     rule_action->action->score = 1.0f;
700     rule_action->action->priority_score = 1.0f;
701     rule->actions.push_back(std::move(rule_action));
702   }
703 
704   // Add custom triggering rule via overwrite.
705   actions_model->preconditions->low_confidence_rules.reset();
706   TriggeringPreconditionsT preconditions;
707   preconditions.suppress_on_low_confidence_input = true;
708   preconditions.low_confidence_rules.reset(new RulesModelT);
709   preconditions.low_confidence_rules->regex_rule.emplace_back(
710       new RulesModel_::RegexRuleT);
711   preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
712   preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
713       "(?i:desaster)";
714   flatbuffers::FlatBufferBuilder preconditions_builder;
715   preconditions_builder.Finish(
716       TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
717   std::string serialize_preconditions = std::string(
718       reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
719       preconditions_builder.GetSize());
720 
721   flatbuffers::FlatBufferBuilder builder;
722   FinishActionsModelBuffer(builder,
723                            ActionsModel::Pack(builder, actions_model.get()));
724   std::unique_ptr<ActionsSuggestions> actions_suggestions =
725       ActionsSuggestions::FromUnownedBuffer(
726           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
727           builder.GetSize(), unilib_.get(), serialize_preconditions);
728 
729   ASSERT_TRUE(actions_suggestions);
730   const ActionsSuggestionsResponse response =
731       actions_suggestions->SuggestActions(
732           {{{/*user_id=*/1, "hello there",
733              /*reference_time_ms_utc=*/0,
734              /*reference_timezone=*/"Europe/Zurich",
735              /*annotations=*/{}, /*locales=*/"en"}}});
736   ASSERT_GE(response.actions.size(), 1);
737   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
738 }
739 #endif
740 
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced. That is the case when
  // there is no TFLite model spec at all, or when its sensitive-topic score
  // output is negative (not configured).
  if (actions_model->tflite_model_spec == nullptr ||
      actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  // Allow no sensitive-topic score at all and enable suppression. The address
  // annotation below must then not yield any action.
  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);

  // Span {11, 15} is "home" in "are you at home?".
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
774 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);

  // The address annotation belongs to the second message,
  // "good! are you at home?", where "home" spans [17, 21).
  AnnotatedSpan annotation;
  annotation.span = {17, 21};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
810 
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskTF2TestModelFileName);
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hello how are you",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Must be a fatal assertion: the checks below index up to actions[3], which
  // would read out of bounds if fewer actions were returned.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].response_text, "Okay");
  EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
  EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
826 
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelGrammarFileName);
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("phone", 0.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Contact us at: *1234",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  const ActionSuggestion& action = response.actions.front();
  EXPECT_EQ(action.type, "call_phone");
  EXPECT_EQ(action.score, 0.0);
  EXPECT_EQ(action.priority_score, 0.0);
  // Must be a fatal assertion: `front()` on an empty vector is undefined.
  ASSERT_THAT(action.annotations, SizeIs(1));
  // The action's annotation covers "*1234", i.e. [15, 20).
  EXPECT_EQ(action.annotations.front().span.span.first, 15);
  EXPECT_EQ(action.annotations.front().span.span.second, 20);
}
848 
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);

  // The flight code "LX38" spans [7, 11) in "I'm on LX38?".
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // The action must point back at the input annotation in message 0.
  EXPECT_THAT(response.actions[0].annotations, SizeIs(1));
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
872 
873 #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest,CreateActionsFromRules)874 TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
875   const std::string actions_model_string =
876       ReadFile(GetModelPath() + kModelFileName);
877   std::unique_ptr<ActionsModelT> actions_model =
878       UnPackActionsModel(actions_model_string.c_str());
879   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
880 
881   actions_model->rules.reset(new RulesModelT());
882   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
883   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
884   rule->pattern = "^(?i:hello\\s(there))$";
885   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
886   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
887   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
888   action->type = "text_reply";
889   action->response_text = "General Kenobi!";
890   action->score = 1.0f;
891   action->priority_score = 1.0f;
892 
893   // Set capturing groups for entity data.
894   rule->actions.back()->capturing_group.emplace_back(
895       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
896   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
897       rule->actions.back()->capturing_group.back().get();
898   greeting_group->group_id = 0;
899   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
900   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
901   greeting_group->entity_field->field.back()->field_name = "greeting";
902   rule->actions.back()->capturing_group.emplace_back(
903       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
904   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
905       rule->actions.back()->capturing_group.back().get();
906   location_group->group_id = 1;
907   location_group->entity_field.reset(new FlatbufferFieldPathT);
908   location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
909   location_group->entity_field->field.back()->field_name = "location";
910 
911   // Set test entity data schema.
912   SetTestEntityDataSchema(actions_model.get());
913 
914   // Use meta data to generate custom serialized entity data.
915   MutableFlatbufferBuilder entity_data_builder(
916       flatbuffers::GetRoot<reflection::Schema>(
917           actions_model->actions_entity_data_schema.data()));
918   std::unique_ptr<MutableFlatbuffer> entity_data =
919       entity_data_builder.NewRoot();
920   entity_data->Set("person", "Kenobi");
921   action->serialized_entity_data = entity_data->Serialize();
922 
923   flatbuffers::FlatBufferBuilder builder;
924   FinishActionsModelBuffer(builder,
925                            ActionsModel::Pack(builder, actions_model.get()));
926   std::unique_ptr<ActionsSuggestions> actions_suggestions =
927       ActionsSuggestions::FromUnownedBuffer(
928           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
929           builder.GetSize(), unilib_.get());
930 
931   const ActionsSuggestionsResponse response =
932       actions_suggestions->SuggestActions(
933           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
934              /*reference_timezone=*/"Europe/Zurich",
935              /*annotations=*/{}, /*locales=*/"en"}}});
936   EXPECT_GE(response.actions.size(), 1);
937   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
938 
939   // Check entity data.
940   const flatbuffers::Table* entity =
941       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
942           response.actions[0].serialized_entity_data.data()));
943   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
944             "hello there");
945   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
946             "there");
947   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
948             "Kenobi");
949 }
950 
TEST_F(ActionsSuggestionsTest,CreateActionsFromRulesWithNormalization)951 TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
952   const std::string actions_model_string =
953       ReadFile(GetModelPath() + kModelFileName);
954   std::unique_ptr<ActionsModelT> actions_model =
955       UnPackActionsModel(actions_model_string.c_str());
956   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
957 
958   actions_model->rules.reset(new RulesModelT());
959   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
960   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
961   rule->pattern = "^(?i:hello\\sthere)$";
962   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
963   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
964   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
965   action->type = "text_reply";
966   action->response_text = "General Kenobi!";
967   action->score = 1.0f;
968   action->priority_score = 1.0f;
969 
970   // Set capturing groups for entity data.
971   rule->actions.back()->capturing_group.emplace_back(
972       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
973   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
974       rule->actions.back()->capturing_group.back().get();
975   greeting_group->group_id = 0;
976   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
977   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
978   greeting_group->entity_field->field.back()->field_name = "greeting";
979   greeting_group->normalization_options.reset(new NormalizationOptionsT);
980   greeting_group->normalization_options->codepointwise_normalization =
981       NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
982       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
983 
984   // Set test entity data schema.
985   SetTestEntityDataSchema(actions_model.get());
986 
987   flatbuffers::FlatBufferBuilder builder;
988   FinishActionsModelBuffer(builder,
989                            ActionsModel::Pack(builder, actions_model.get()));
990   std::unique_ptr<ActionsSuggestions> actions_suggestions =
991       ActionsSuggestions::FromUnownedBuffer(
992           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
993           builder.GetSize(), unilib_.get());
994 
995   const ActionsSuggestionsResponse response =
996       actions_suggestions->SuggestActions(
997           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
998              /*reference_timezone=*/"Europe/Zurich",
999              /*annotations=*/{}, /*locales=*/"en"}}});
1000   EXPECT_GE(response.actions.size(), 1);
1001   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
1002 
1003   // Check entity data.
1004   const flatbuffers::Table* entity =
1005       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
1006           response.actions[0].serialized_entity_data.data()));
1007   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
1008             "HELLOTHERE");
1009 }
1010 
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);

  // Capturing group 1 (the keyword) becomes a text reply, lowercased.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;
  code_group->normalization_options.reset(new NormalizationOptionsT);
  code_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Must be a fatal assertion: the check below indexes into `actions`.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "stop");
}
1055 
TEST_F(ActionsSuggestionsTest,CreatesActionsFromGrammarRules)1056 TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
1057   const std::string actions_model_string =
1058       ReadFile(GetModelPath() + kModelFileName);
1059   std::unique_ptr<ActionsModelT> actions_model =
1060       UnPackActionsModel(actions_model_string.c_str());
1061   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1062 
1063   actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
1064 
1065   // Set tokenizer options.
1066   RulesModel_::GrammarRulesT* action_grammar_rules =
1067       actions_model->rules->grammar_rules.get();
1068   action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1069   action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
1070   action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
1071       false;
1072 
1073   // Setup test rules.
1074   action_grammar_rules->rules.reset(new grammar::RulesSetT);
1075   grammar::LocaleShardMap locale_shard_map =
1076       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1077   grammar::Rules rules(locale_shard_map);
1078   rules.Add(
1079       "<knock>", {"<^>", "ventura", "!?", "<$>"},
1080       /*callback=*/
1081       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1082       /*callback_param=*/0);
1083   rules.Finalize().Serialize(/*include_debug_information=*/false,
1084                              action_grammar_rules->rules.get());
1085   action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1086   RulesModel_::RuleActionSpecT* actions_spec =
1087       action_grammar_rules->actions.back().get();
1088   actions_spec->action.reset(new ActionSuggestionSpecT);
1089   actions_spec->action->response_text = "Yes, Satan?";
1090   actions_spec->action->priority_score = 1.0;
1091   actions_spec->action->score = 1.0;
1092   actions_spec->action->type = "text_reply";
1093   action_grammar_rules->rule_match.emplace_back(
1094       new RulesModel_::GrammarRules_::RuleMatchT);
1095   action_grammar_rules->rule_match.back()->action_id.push_back(0);
1096 
1097   flatbuffers::FlatBufferBuilder builder;
1098   FinishActionsModelBuffer(builder,
1099                            ActionsModel::Pack(builder, actions_model.get()));
1100   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1101       ActionsSuggestions::FromUnownedBuffer(
1102           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1103           builder.GetSize(), unilib_.get());
1104 
1105   const ActionsSuggestionsResponse response =
1106       actions_suggestions->SuggestActions(
1107           {{{/*user_id=*/1, "Ventura!",
1108              /*reference_time_ms_utc=*/0,
1109              /*reference_timezone=*/"Europe/Zurich",
1110              /*annotations=*/{}, /*locales=*/"en"}}});
1111 
1112   EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
1113 }
1114 
1115 #if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
TEST_F(ActionsSuggestionsTest,CreatesActionsWithAnnotationsFromGrammarRules)1116 TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
1117   std::unique_ptr<Annotator> annotator =
1118       Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
1119   const std::string actions_model_string =
1120       ReadFile(GetModelPath() + kModelFileName);
1121   std::unique_ptr<ActionsModelT> actions_model =
1122       UnPackActionsModel(actions_model_string.c_str());
1123   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1124 
1125   actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
1126 
1127   // Set tokenizer options.
1128   RulesModel_::GrammarRulesT* action_grammar_rules =
1129       actions_model->rules->grammar_rules.get();
1130   action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1131   action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
1132   action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
1133       false;
1134 
1135   // Setup test rules.
1136   action_grammar_rules->rules.reset(new grammar::RulesSetT);
1137   grammar::LocaleShardMap locale_shard_map =
1138       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1139   grammar::Rules rules(locale_shard_map);
1140   rules.Add(
1141       "<event>", {"it", "is", "at", "<time>"},
1142       /*callback=*/
1143       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1144       /*callback_param=*/0);
1145   rules.BindAnnotation("<time>", "time");
1146   rules.AddAnnotation("datetime");
1147   rules.Finalize().Serialize(/*include_debug_information=*/false,
1148                              action_grammar_rules->rules.get());
1149   action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1150   RulesModel_::RuleActionSpecT* actions_spec =
1151       action_grammar_rules->actions.back().get();
1152   actions_spec->action.reset(new ActionSuggestionSpecT);
1153   actions_spec->action->priority_score = 1.0;
1154   actions_spec->action->score = 1.0;
1155   actions_spec->action->type = "create_event";
1156   action_grammar_rules->rule_match.emplace_back(
1157       new RulesModel_::GrammarRules_::RuleMatchT);
1158   action_grammar_rules->rule_match.back()->action_id.push_back(0);
1159 
1160   flatbuffers::FlatBufferBuilder builder;
1161   FinishActionsModelBuffer(builder,
1162                            ActionsModel::Pack(builder, actions_model.get()));
1163   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1164       ActionsSuggestions::FromUnownedBuffer(
1165           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1166           builder.GetSize(), unilib_.get());
1167 
1168   const ActionsSuggestionsResponse response =
1169       actions_suggestions->SuggestActions(
1170           {{{/*user_id=*/1, "it is at 10:30",
1171              /*reference_time_ms_utc=*/0,
1172              /*reference_timezone=*/"Europe/Zurich",
1173              /*annotations=*/{}, /*locales=*/"en"}}},
1174           annotator.get());
1175 
1176   EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
1177 }
1178 #endif
1179 
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(actions_suggestions);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  const auto num_actions = response.actions.size();

  // Add custom rule for location sharing. It duplicates the model's own
  // suggestion, so the total number of actions must not grow.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  actions_model->rules->regex_rule.back()->pattern =
      "^(?i:where are you[.?]?)$";
  actions_model->rules->regex_rule.back()->actions.emplace_back(
      new RulesModel_::RuleActionSpecT);
  actions_model->rules->regex_rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->regex_rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestionsTypes::ShareLocation();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, SizeIs(num_actions));
}
1232 
TEST_F(ActionsSuggestionsTest,DeduplicateConflictingActions)1233 TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
1234   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1235       LoadTestModel(kModelFileName);
1236   AnnotatedSpan annotation;
1237   annotation.span = {7, 11};
1238   annotation.classification = {
1239       ClassificationResult(Collections::Flight(), 1.0)};
1240   ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
1241       {{{/*user_id=*/1, "I'm on LX38",
1242          /*reference_time_ms_utc=*/0,
1243          /*reference_timezone=*/"Europe/Zurich",
1244          /*annotations=*/{annotation},
1245          /*locales=*/"en"}}});
1246 
1247   // Check that the phone actions are present.
1248   EXPECT_GE(response.actions.size(), 1);
1249   EXPECT_EQ(response.actions[0].type, "track_flight");
1250 
1251   // Add custom rule.
1252   const std::string actions_model_string =
1253       ReadFile(GetModelPath() + kModelFileName);
1254   std::unique_ptr<ActionsModelT> actions_model =
1255       UnPackActionsModel(actions_model_string.c_str());
1256   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1257 
1258   actions_model->rules.reset(new RulesModelT());
1259   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
1260   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
1261   rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
1262   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1263   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
1264   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
1265   action->score = 1.0f;
1266   action->priority_score = 2.0f;
1267   action->type = "test_code";
1268   rule->actions.back()->capturing_group.emplace_back(
1269       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
1270   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
1271       rule->actions.back()->capturing_group.back().get();
1272   code_group->group_id = 1;
1273   code_group->annotation_name = "code";
1274   code_group->annotation_type = "code";
1275 
1276   flatbuffers::FlatBufferBuilder builder;
1277   FinishActionsModelBuffer(builder,
1278                            ActionsModel::Pack(builder, actions_model.get()));
1279   actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
1280       reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1281       builder.GetSize(), unilib_.get());
1282 
1283   response = actions_suggestions->SuggestActions(
1284       {{{/*user_id=*/1, "I'm on LX38",
1285          /*reference_time_ms_utc=*/0,
1286          /*reference_timezone=*/"Europe/Zurich",
1287          /*annotations=*/{annotation},
1288          /*locales=*/"en"}}});
1289   EXPECT_GE(response.actions.size(), 1);
1290   EXPECT_EQ(response.actions[0].type, "test_code");
1291 }
1292 #endif
1293 
// Two address annotations on one message ("home" scored 1.0, "work" scored
// 2.0) should each yield a view_map action, ordered by descending score.
TEST_F(ActionsSuggestionsTest, RanksActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions != nullptr);
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};  // "home"
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};  // "work"
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Fatal: the indexed accesses below need at least two actions.
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_FLOAT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_FLOAT_EQ(response.actions[1].score, 1.0);
}
1315 
// The visitor reports whether it received a model at all: visiting the real
// test model should yield true, visiting a non-existent path false.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  const auto is_non_null = [](const ActionsModel* model) {
    return model != nullptr;
  };
  EXPECT_TRUE(
      VisitActionsModel<bool>(GetModelPath() + kModelFileName, is_non_null));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", is_non_null));
}
1332 
// Checks the hash-gram model's suggestions for three single-message inputs.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);

  // Runs the model on a one-message English conversation from user 1.
  const auto suggest = [&actions_suggestions](const char* text) {
    return actions_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  // A plain greeting triggers nothing.
  EXPECT_THAT(suggest("hello").actions, IsEmpty());

  // Exactly one action of the expected type for each question.
  EXPECT_THAT(
      suggest("where are you").actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
  EXPECT_THAT(
      suggest("do you know johns number").actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
}
1372 
1373 // Test class to expose token embedding methods for testing.
1374 class TestingMessageEmbedder : private ActionsSuggestions {
1375  public:
1376   explicit TestingMessageEmbedder(const ActionsModel* model);
1377 
1378   using ActionsSuggestions::EmbedAndFlattenTokens;
1379   using ActionsSuggestions::EmbedTokensPerMessage;
1380 
1381  protected:
1382   // EmbeddingExecutor that always returns features based on
1383   // the id of the sparse features.
1384   class FakeEmbeddingExecutor : public EmbeddingExecutor {
1385    public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1386     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1387                       const int dest_size) const override {
1388       TC3_CHECK_GE(dest_size, 1);
1389       EXPECT_EQ(sparse_features.size(), 1);
1390       dest[0] = sparse_features.data()[0];
1391       return true;
1392     }
1393   };
1394 
1395   std::unique_ptr<UniLib> unilib_;
1396 };
1397 
TestingMessageEmbedder(const ActionsModel * model)1398 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
1399     : unilib_(CreateUniLibForTesting()) {
1400   model_ = model;
1401   const ActionsTokenFeatureProcessorOptions* options =
1402       model->feature_processor_options();
1403   feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
1404   embedding_executor_.reset(new FakeEmbeddingExecutor());
1405   EXPECT_TRUE(
1406       EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1407   EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1408   EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1409   token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1410   EXPECT_EQ(token_embedding_size_, 1);
1411 }
1412 
1413 class EmbeddingTest : public testing::Test {
1414  protected:
EmbeddingTest()1415   explicit EmbeddingTest() {
1416     model_.feature_processor_options.reset(
1417         new ActionsTokenFeatureProcessorOptionsT);
1418     options_ = model_.feature_processor_options.get();
1419     options_->chargram_orders = {1};
1420     options_->num_buckets = 1000;
1421     options_->embedding_size = 1;
1422     options_->start_token_id = 0;
1423     options_->end_token_id = 1;
1424     options_->padding_token_id = 2;
1425     options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1426   }
1427 
CreateTestingMessageEmbedder()1428   TestingMessageEmbedder CreateTestingMessageEmbedder() {
1429     flatbuffers::FlatBufferBuilder builder;
1430     FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1431     buffer_ = builder.Release();
1432     return TestingMessageEmbedder(
1433         flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1434   }
1435 
1436   flatbuffers::DetachedBuffer buffer_;
1437   ActionsModelT model_;
1438   ActionsTokenFeatureProcessorOptionsT* options_;
1439 };
1440 
// Without min/max bounds, one message of three tokens embeds to exactly three
// values in token order.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  ASSERT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
}
1460 
// With min_num_tokens_per_message = 5, a three-token message is padded at the
// end with two padding-token embeddings.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[3], FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[4], FloatEq(options_->padding_token_id));
}
1483 
// With max_num_tokens_per_message = 2, the oldest token ("a") is dropped and
// the last two are kept.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  ASSERT_EQ(embeddings.size(), 2);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("c")));
}
1502 
// Messages of 3 and 2 tokens are laid out in fixed-width rows of the longest
// message (3), so the shorter one gets a trailing padding embedding: 6 values.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  ASSERT_EQ(embeddings.size(), 6);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[4], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
}
1527 
// Flattening wraps each message in start/end token embeddings.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
}
1549 
// With min_num_total_tokens = 7, the flattened sequence (start, 3 tokens, end)
// is padded at the end with two padding embeddings.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  ASSERT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[6], FloatEq(options_->padding_token_id));
}
1574 
// With max_num_total_tokens = 3, the start token and "a" are dropped from the
// front of the flattened sequence, keeping "b", "c" and the end token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  ASSERT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[2], FloatEq(options_->end_token_id));
}
1594 
// Each message is wrapped in its own start/end pair: 5 + 4 = 9 values.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  ASSERT_EQ(embeddings.size(), 9);
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[5], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[6], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[7], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[8], FloatEq(options_->end_token_id));
}
1623 
// With max_num_total_tokens = 7, the first three flattened values (start, "a",
// "b") are dropped from the 10-value sequence.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding value of a single-character token in this fixture.
  const auto bucket_of = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal checks: the element accesses below depend on them.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  ASSERT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[1], FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[2], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[4], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[5], FloatEq(bucket_of("f")));
  EXPECT_THAT(embeddings[6], FloatEq(options_->end_token_id));
}
1651 
// With default options, every head of the multi-task model produces output.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(),
            11 /* 8 binary classification + 3 smart replies*/);
}
1663 
// Threshold used to switch a classification head off. It is above 1.0, so
// assuming head scores are probabilities, no score can pass it.
constexpr float kDisableThresholdVal = 2.0f;

// Names of model parameters understood by the multi-task test model.
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
1676 
GetOptionsToDisableAllClassification()1677 ActionSuggestionOptions GetOptionsToDisableAllClassification() {
1678   ActionSuggestionOptions options;
1679   // Disable all classification heads.
1680   options.model_parameters.insert(
1681       {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1682   options.model_parameters.insert(
1683       {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1684   options.model_parameters.insert(
1685       {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1686   options.model_parameters.insert(
1687       {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1688   options.model_parameters.insert(
1689       {kRestaurantsThreshold,
1690        libtextclassifier3::Variant(kDisableThresholdVal)});
1691   options.model_parameters.insert(
1692       {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1693   options.model_parameters.insert(
1694       {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1695   options.model_parameters.insert(
1696       {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1697   return options;
1698 }
1699 
// With all classification heads disabled, only the smart replies remain.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  const ActionSuggestionOptions options =
      GetOptionsToDisableAllClassification();
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  EXPECT_THAT(response.actions,
              ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
                          IsSmartReply("I'm home")));
  EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
}
1716 
// Number of entries in the synthetic user profile, and the model parameter
// names under which its token indexes and weights are passed.
constexpr int kUserProfileSize = 1000;
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
1720 
GetOptionsForSmartReplyP13nModel()1721 ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
1722   ActionSuggestionOptions options;
1723   const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
1724   const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
1725   options.model_parameters.insert(
1726       {kUserProfileTokenIndex,
1727        libtextclassifier3::Variant(user_profile_token_indexes)});
1728   options.model_parameters.insert(
1729       {kUserProfileTokenWeight,
1730        libtextclassifier3::Variant(user_profile_token_weights)});
1731   return options;
1732 }
1733 
// The personalized smart-reply model accepts user-profile parameters and still
// returns three smart replies.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskSrP13nTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  const ActionSuggestionOptions options = GetOptionsForSmartReplyP13nModel();
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
}
1746 
// Starting from all heads disabled, re-enables only location sharing and sets
// a diversification distance threshold for the smart replies.
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  options.model_parameters.insert(
      {kDiversificationParm, libtextclassifier3::Variant(0.5f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // NOTE(review): the emoji literal below appears mis-encoded in this copy of
  // the file (it renders as replacement characters); verify it against the
  // model's actual output.
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsActionOfType("LOCATION_SHARE"), IsSmartReply("Here"),
                  IsSmartReply("Yes"), IsSmartReply("��")));
  EXPECT_EQ(response.actions.size(), 4 /*1 location share + 3 smart replies*/);
}
1768 
// Starting from all heads disabled, re-enables location sharing and the
// reminder head and sets an empirical probability factor of 2.0.
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  // reminder head always trigger since the threshold is zero.
  options.model_parameters[kReminderThreshold] =
      libtextclassifier3::Variant(0.0f);
  options.model_parameters.insert(
      {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // NOTE(review): the emoji literal below appears mis-encoded in this copy of
  // the file (it renders as replacement characters, despite the comment saying
  // it differs from the previous test's); verify it against the model output.
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
                  IsSmartReply("Yes"),
                  /*Different emoji than previous test*/ IsSmartReply("��"),
                  IsActionOfType("REMINDER_INTENT")));
  EXPECT_EQ(response.actions.size(),
            5 /*1 location share + 3 smart replies + 1 reminder*/);
}
1795 
// The smart-reply + emoji model returns five actions, led by an emoji concept
// followed by a text reply.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);
  ASSERT_TRUE(actions_suggestions != nullptr);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal: the indexed accesses below need the actions to be present.
  ASSERT_EQ(response.actions.size(), 5);
  // NOTE(review): the emoji literal below appears mis-encoded in this copy of
  // the file (it renders as replacement characters); verify it against the
  // model's actual output.
  EXPECT_EQ(response.actions[0].response_text, "��");
  EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
  EXPECT_EQ(response.actions[1].response_text, "Yes");
  EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
}
1812 
// A message flagged as sensitive yields no actions and sets is_sensitive,
// without being attributed to low-confidence filtering.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kSensitiveTFliteModelFileName);
  ASSERT_TRUE(actions_suggestions != nullptr);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I want to kill myself",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
  EXPECT_TRUE(response.is_sensitive);
  EXPECT_FALSE(response.output_filtered_low_confidence);
}
1827 
1828 }  // namespace
1829 }  // namespace libtextclassifier3
1830