1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/actions-suggestions.h"
18
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23
24 #include "actions/actions_model_generated.h"
25 #include "actions/test-utils.h"
26 #include "actions/zlib-utils.h"
27 #include "annotator/collections.h"
28 #include "annotator/types.h"
29 #include "utils/flatbuffers/flatbuffers.h"
30 #include "utils/flatbuffers/flatbuffers_generated.h"
31 #include "utils/flatbuffers/mutable.h"
32 #include "utils/grammar/utils/locale-shard-map.h"
33 #include "utils/grammar/utils/rules.h"
34 #include "utils/hash/farmhash.h"
35 #include "utils/jvm-test-utils.h"
36 #include "utils/test-data-test-utils.h"
37 #include "gmock/gmock.h"
38 #include "gtest/gtest.h"
39 #include "flatbuffers/flatbuffers.h"
40 #include "flatbuffers/reflection.h"
41
42 namespace libtextclassifier3 {
43 namespace {
44
45 using ::testing::ElementsAre;
46 using ::testing::FloatEq;
47 using ::testing::IsEmpty;
48 using ::testing::NotNull;
49 using ::testing::SizeIs;
50
// Test model files. Each name is relative to the directory returned by
// GetModelPath().

// Baseline model used by most tests.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
// Model variant for grammar-based tests.
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";

// Multi-task model variants.
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";

// Hash-gram model variant.
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";

// Further multi-task and sensitivity model variants.
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";
66
// Reads the entire contents of `file_name` into a string.
//
// The test models are serialized flatbuffers, so the file is opened in binary
// mode: text mode lets the platform translate line endings, which would
// corrupt the buffer. Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
71
GetModelPath()72 std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
73
74 class ActionsSuggestionsTest : public testing::Test {
75 protected:
ActionsSuggestionsTest()76 explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
LoadTestModel(const std::string model_file_name)77 std::unique_ptr<ActionsSuggestions> LoadTestModel(
78 const std::string model_file_name) {
79 return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
80 unilib_.get());
81 }
LoadHashGramTestModel()82 std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
83 return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
84 unilib_.get());
85 }
LoadMultiTaskTestModel()86 std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
87 return ActionsSuggestions::FromPath(
88 GetModelPath() + kMultiTaskModelFileName, unilib_.get());
89 }
90
LoadMultiTaskSrP13nTestModel()91 std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
92 return ActionsSuggestions::FromPath(
93 GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
94 }
95 std::unique_ptr<UniLib> unilib_;
96 };
97
// The baseline test model can be loaded.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> loaded_model =
      LoadTestModel(kModelFileName);
  EXPECT_THAT(loaded_model, NotNull());
}
101
// A message ending in a truncated UTF-8 sequence produces no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // "\xf0\x9f" is only the first half of a four-byte UTF-8 sequence.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
113
// A message containing encoded surrogate code points produces no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Each "\xed\xa0\x80" encodes the surrogate U+D800, which is not valid
  // UTF-8.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
127
// A simple question yields a location-sharing action plus smart replies.
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // share_location + 2 smart replies. SizeIs avoids comparing the size_t size
  // against a signed literal.
  EXPECT_THAT(response.actions, SizeIs(3));
}
138
// A message tagged with an unsupported locale ("zz") yields no actions.
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
149
// An "address" annotation on the message produces a `view_map` action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Span [11, 15) covers "home" in the message below.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  // Unsigned literal: avoid a signed/unsigned comparison with size().
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
167
// An annotation mapping with an entity field copies the annotated text into
// the action's serialized entity data.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config: "address" annotations map to a
  // `save_location` action whose `location` entity field gets the span text.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  // NOTE(review): FromUnownedBuffer presumably keeps referencing `builder`'s
  // memory, so `builder` must stay alive for the rest of the test.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Span [11, 15) covers "home" in the message below.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Guard against empty data and a missing field, which would
  // otherwise be undefined behavior instead of a test failure.
  const auto& entity_data = response.actions.front().serialized_entity_data;
  ASSERT_FALSE(entity_data.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(entity_data.data()));
  // Vtable offset 6 is the table's second field; presumably `location` in the
  // schema installed by SetTestEntityDataSchema -- TODO confirm.
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "home");
}
220
// Normalization options on an annotation mapping are applied to the text that
// is copied into the entity field (here: uppercasing).
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsFromAnnotationsWithNormalization) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config: "address" annotations map to a
  // `save_location` action whose `location` entity field gets the uppercased
  // span text.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";
  mapping->normalization_options.reset(new NormalizationOptionsT);
  mapping->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  // NOTE(review): FromUnownedBuffer presumably keeps referencing `builder`'s
  // memory, so `builder` must stay alive for the rest of the test.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Span [11, 15) covers "home" in the message below.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the normalized text of the
  // annotation. Guard against empty data and a missing field, which would
  // otherwise be undefined behavior instead of a test failure.
  const auto& entity_data = response.actions.front().serialized_entity_data;
  ASSERT_FALSE(entity_data.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(entity_data.data()));
  // Vtable offset 6 is the table's second field; presumably `location` in the
  // schema installed by SetTestEntityDataSchema -- TODO confirm.
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "HOME");
}
277
// Two annotations of the same flight number yield a single `track_flight`
// action carrying the higher score, followed by the email action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Spans [11, 15) and [35, 39) both cover "LX38"; [43, 56) covers the email.
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
307
// With `deduplicate_annotations` disabled, each flight annotation produces its
// own `track_flight` action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Spans [11, 15) and [35, 39) both cover "LX38"; [43, 56) covers the email.
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
350
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)351 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
352 const std::function<void(ActionsModelT*)>& set_config_fn,
353 const UniLib* unilib = nullptr) {
354 const std::string actions_model_string =
355 ReadFile(GetModelPath() + kModelFileName);
356 std::unique_ptr<ActionsModelT> actions_model =
357 UnPackActionsModel(actions_model_string.c_str());
358
359 // Set custom config.
360 set_config_fn(actions_model.get());
361
362 // Disable smart reply for easier testing.
363 actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
364
365 flatbuffers::FlatBufferBuilder builder;
366 FinishActionsModelBuffer(builder,
367 ActionsModel::Pack(builder, actions_model.get()));
368 std::unique_ptr<ActionsSuggestions> actions_suggestions =
369 ActionsSuggestions::FromUnownedBuffer(
370 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
371 builder.GetSize(), unilib);
372
373 AnnotatedSpan flight_annotation;
374 flight_annotation.span = {15, 19};
375 flight_annotation.classification = {ClassificationResult("flight", 2.0)};
376 AnnotatedSpan email_annotation;
377 email_annotation.span = {0, 16};
378 email_annotation.classification = {ClassificationResult("email", 1.0)};
379
380 return actions_suggestions->SuggestActions(
381 {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
382 "hehe@android.com",
383 /*reference_time_ms_utc=*/0,
384 /*reference_timezone=*/"Europe/Zurich",
385 /*annotations=*/
386 {email_annotation},
387 /*locales=*/"en"},
388 {/*user_id=*/2,
389 "yoyo@android.com",
390 /*reference_time_ms_utc=*/0,
391 /*reference_timezone=*/"Europe/Zurich",
392 /*annotations=*/
393 {email_annotation},
394 /*locales=*/"en"},
395 {/*user_id=*/1,
396 "test@android.com",
397 /*reference_time_ms_utc=*/0,
398 /*reference_timezone=*/"Europe/Zurich",
399 /*annotations=*/
400 {email_annotation},
401 /*locales=*/"en"},
402 {/*user_id=*/1,
403 "I am on flight LX38.",
404 /*reference_time_ms_utc=*/0,
405 /*reference_timezone=*/"Europe/Zurich",
406 /*annotations=*/
407 {flight_annotation},
408 /*locales=*/"en"}}});
409 }
410
// With a history of one message from anyone and from the last person, only
// the final (flight) message contributes an action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
425
// Allowing three messages from the last sender picks up both of user 1's
// messages (flight and email), but nothing from other users.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
441
// Allowing two messages from anyone yields actions from the last two messages.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
457
// Allowing three messages from anyone yields actions from the last three
// messages.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
475
// A generous history limit still excludes the local user's message when
// local-user messages are disabled, so only three actions are produced.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
493
// Including local-user messages and not stopping at the last sent message
// yields an action for every message in the conversation.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT so a size mismatch stops the test before indexing out of bounds.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
512
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")513 void TestSuggestActionsWithThreshold(
514 const std::function<void(ActionsModelT*)>& set_value_fn,
515 const UniLib* unilib = nullptr, const int expected_size = 0,
516 const std::string& preconditions_overwrite = "") {
517 const std::string actions_model_string =
518 ReadFile(GetModelPath() + kModelFileName);
519 std::unique_ptr<ActionsModelT> actions_model =
520 UnPackActionsModel(actions_model_string.c_str());
521 set_value_fn(actions_model.get());
522 flatbuffers::FlatBufferBuilder builder;
523 FinishActionsModelBuffer(builder,
524 ActionsModel::Pack(builder, actions_model.get()));
525 std::unique_ptr<ActionsSuggestions> actions_suggestions =
526 ActionsSuggestions::FromUnownedBuffer(
527 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
528 builder.GetSize(), unilib, preconditions_overwrite);
529 ASSERT_TRUE(actions_suggestions);
530 const ActionsSuggestionsResponse response =
531 actions_suggestions->SuggestActions(
532 {{{/*user_id=*/1, "I have the low-ground. Where are you?",
533 /*reference_time_ms_utc=*/0,
534 /*reference_timezone=*/"Europe/Zurich",
535 /*annotations=*/{}, /*locales=*/"en"}}});
536 EXPECT_LE(response.actions.size(), expected_size);
537 }
538
// A smart-reply triggering score of 1.0 suppresses smart replies, leaving at
// most one (non-smart-reply) action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, unilib_.get(),
                                  /*expected_size=*/1);
}
548
// A minimum reply score of 1.0 filters out smart replies, leaving at most one
// (non-smart-reply) action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
  const auto raise_min_reply_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_min_reply_score, unilib_.get(),
                                  /*expected_size=*/1);
}
558
// Setting the maximum sensitive-topic score to 0 still allows up to four
// actions, because the test model produces no sensitivity prediction.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
  const auto zero_sensitivity_limit = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitivity_limit, unilib_.get(),
                                  /*expected_size=*/4);
}
567
// With max_input_length set to 0, no actions are expected (the helper's
// default expected_size is 0).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
  const auto zero_max_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_length, unilib_.get());
}
575
// With min_input_length set to 100 (longer than the helper's test message), no
// actions are expected (the helper's default expected_size is 0).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
  const auto long_min_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(long_min_length, unilib_.get());
}
583
// Preconditions passed as a serialized overwrite take effect without modifying
// the model: a max_input_length of 0 leads to no actions.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;

  flatbuffers::FlatBufferBuilder overwrite_builder;
  overwrite_builder.Finish(
      TriggeringPreconditions::Pack(overwrite_builder, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(overwrite_builder.GetBufferPointer()),
      overwrite_builder.GetSize());

  // The model itself is left untouched.
  TestSuggestActionsWithThreshold([](ActionsModelT*) {}, unilib_.get(),
                                  /*expected_size=*/0, serialized_overwrite);
}
597
598 #ifdef TC3_UNILIB_ICU
// Input matching a low-confidence rule ("low-ground") is suppressed when
// suppress_on_low_confidence_input is set; the helper's default expected_size
// of 0 then requires that no actions are returned.
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
  const auto add_low_confidence_rule = [](ActionsModelT* actions_model) {
    std::unique_ptr<RulesModel_::RegexRuleT> low_confidence_rule(
        new RulesModel_::RegexRuleT);
    low_confidence_rule->pattern = "low-ground";

    actions_model->preconditions->suppress_on_low_confidence_input = true;
    actions_model->low_confidence_rules.reset(new RulesModelT);
    actions_model->low_confidence_rules->regex_rule.push_back(
        std::move(low_confidence_rule));
  };
  TestSuggestActionsWithThreshold(add_low_confidence_rule, unilib_.get());
}
611
// A low-confidence rule with both an input pattern and an output pattern
// removes only the replies matching the output pattern ("desaster"), so the
// other rule-based reply survives.
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Add custom triggering rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";

  // Appends a `text_reply` action with the given response text to `rule`.
  const auto add_text_reply = [rule](const std::string& response_text) {
    std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
        new RulesModel_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = response_text;
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Add input-output low confidence rule.
  actions_model->preconditions->suppress_on_low_confidence_input = true;
  actions_model->low_confidence_rules.reset(new RulesModelT);
  actions_model->low_confidence_rules->regex_rule.emplace_back(
      new RulesModel_::RegexRuleT);
  actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
  actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
      "(?i:desaster)";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Unsigned literal: avoid a signed/unsigned comparison with size().
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
669
// Same as the input/output low-confidence scenario, but the low-confidence
// rule is supplied through serialized preconditions passed at load time
// instead of being stored in the model.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsLowConfidenceInputOutputOverwrite) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  actions_model->low_confidence_rules.reset();

  // Add custom triggering rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";

  // Appends a `text_reply` action with the given response text to `rule`.
  const auto add_text_reply = [rule](const std::string& response_text) {
    std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
        new RulesModel_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = response_text;
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Add the low-confidence rule via a preconditions overwrite instead of the
  // model itself.
  actions_model->preconditions->low_confidence_rules.reset();
  TriggeringPreconditionsT preconditions;
  preconditions.suppress_on_low_confidence_input = true;
  preconditions.low_confidence_rules.reset(new RulesModelT);
  preconditions.low_confidence_rules->regex_rule.emplace_back(
      new RulesModel_::RegexRuleT);
  preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
  preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
      "(?i:desaster)";
  flatbuffers::FlatBufferBuilder preconditions_builder;
  preconditions_builder.Finish(
      TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
  std::string serialize_preconditions = std::string(
      reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
      preconditions_builder.GetSize());

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get(), serialize_preconditions);

  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Unsigned literal: avoid a signed/unsigned comparison with size().
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
739 #endif
740
// With suppress_on_sensitive_topic enabled and max_sensitive_topic_score set
// to 0, an address annotation must not produce any action. The test returns
// early (passing trivially) when the model's output_sensitive_topic_score is
// negative, i.e. the model does not produce a sensitivity score.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null suggester below.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // Covers "home" in the message below.
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
774
// With the conversation history limit raised to 10, an address annotation on
// the last message of a two-message conversation yields a view_map action
// with score 1.0 first.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // Covers "home" in "good! are you at home?" (the second message).
  annotation.span = {17, 21};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
810
// The TF2 multi-task test model must return exactly four actions for
// "Hello how are you": an "Okay" REPLY_SUGGESTION first and a
// TEST_CLASSIFIER_INTENT last.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskTF2TestModelFileName);
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hello how are you",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal, so the accesses to actions[0] and actions[3] below stay in bounds.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].response_text, "Okay");
  EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
  EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
826
// With the grammar test model, "Contact us at: *1234" must yield a call_phone
// action first (score and priority 0) carrying a single annotation at
// {15, 20}, i.e. "*1234".
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelGrammarFileName);
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // NOTE(review): {11, 15} covers "at: ", not the "*1234" number at {15, 20}
  // that the expected action annotation below refers to. Confirm this input
  // span is intentional.
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("phone", 0.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Contact us at: *1234",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "call_phone");
  EXPECT_EQ(response.actions.front().score, 0.0);
  EXPECT_EQ(response.actions.front().priority_score, 0.0);
  // Fatal, so that annotations.front() below is never called on an empty
  // vector.
  ASSERT_THAT(response.actions.front().annotations, SizeIs(1));
  EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15);
  EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20);
}
848
// A flight annotation on the message must yield a track_flight action first
// (score 1.0) that carries exactly that annotation: message index 0 and the
// same span.
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // Covers the flight number "LX38" in "I'm on LX38?".
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  EXPECT_THAT(response.actions[0].annotations, SizeIs(1));
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
872
873 #ifdef TC3_UNILIB_ICU
// Regex rule actions must carry entity data. Capturing groups 0 and 1 fill the
// "greeting" and "location" fields, alongside the "person" field that is
// pre-serialized into the action spec.
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
      rule->actions.back()->capturing_group.back().get();
  location_group->group_id = 1;
  location_group->entity_field.reset(new FlatbufferFieldPathT);
  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  location_group->entity_field->field.back()->field_name = "location";

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  // Use meta data to generate custom serialized entity data.
  MutableFlatbufferBuilder entity_data_builder(
      flatbuffers::GetRoot<reflection::Schema>(
          actions_model->actions_entity_data_schema.data()));
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder.NewRoot();
  entity_data->Set("person", "Kenobi");
  action->serialized_entity_data = entity_data->Serialize();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal: everything below indexes into the first action.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data. Never hand GetAnyRoot an empty buffer.
  ASSERT_FALSE(response.actions[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
            "hello there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
            "Kenobi");
}
950
// Normalization options on a capturing group (drop whitespace + uppercase)
// must be applied to the entity data written for that group: "hello there"
// becomes "HELLOTHERE".
TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\sthere)$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  greeting_group->normalization_options.reset(new NormalizationOptionsT);
  greeting_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal: everything below indexes into the first action.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data. Never hand GetAnyRoot an empty buffer.
  ASSERT_FALSE(response.actions[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
            "HELLOTHERE");
}
1010
// A capturing group with a text_reply spec must turn the captured text
// ("STOP") into a reply, lowercased by the group's normalization options.
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;
  code_group->normalization_options.reset(new NormalizationOptionsT);
  code_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal: the check below indexes into the first action.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "stop");
}
1055
// Installs a grammar with a single root rule
//   <knock> ::= <^> ventura !? <$>
// whose match is routed to action 0, a "Yes, Satan?" text reply. Checks that
// the message "Ventura!" yields exactly that one reply.
TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
  const std::string serialized_model =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> model =
      UnPackActionsModel(serialized_model.c_str());
  ASSERT_TRUE(DecompressActionsModel(model.get()));

  model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
  RulesModel_::GrammarRulesT* const grammar_spec =
      model->rules->grammar_rules.get();

  // ICU tokenization, without whitespace tokens.
  grammar_spec->tokenizer_options.reset(new ActionsTokenizerOptionsT);
  ActionsTokenizerOptionsT* const tokenizer =
      grammar_spec->tokenizer_options.get();
  tokenizer->type = TokenizationType_ICU;
  tokenizer->icu_preserve_whitespace_tokens = false;

  // Compile the test grammar into the model.
  grammar_spec->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules grammar_rules(shard_map);
  grammar_rules.Add(
      "<knock>", {"<^>", "ventura", "!?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0);
  grammar_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                     grammar_spec->rules.get());

  // Action 0: the text reply offered when the rule matches.
  grammar_spec->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  auto& reply = grammar_spec->actions.back()->action;
  reply.reset(new ActionSuggestionSpecT);
  reply->type = "text_reply";
  reply->response_text = "Yes, Satan?";
  reply->score = 1.0;
  reply->priority_score = 1.0;

  // Route the rule match to action 0.
  grammar_spec->rule_match.emplace_back(
      new RulesModel_::GrammarRules_::RuleMatchT);
  grammar_spec->rule_match.back()->action_id.push_back(0);

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, model.get()));
  std::unique_ptr<ActionsSuggestions> suggester =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());

  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Ventura!",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
}
1114
1115 #if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
// The grammar rule <event> ::= it is at <time> binds <time> to the "time"
// annotation and maps a match to a create_event action. With the Annotator
// loaded from en.fb (presumably the source of the time annotation), the
// message "it is at 10:30" must yield exactly one create_event action.
TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
  // Fail loudly if the annotator model cannot be loaded; otherwise a null
  // annotator would be passed to SuggestActions below.
  ASSERT_TRUE(annotator);
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);

  // Set tokenizer options.
  RulesModel_::GrammarRulesT* action_grammar_rules =
      actions_model->rules->grammar_rules.get();
  action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
  action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
  action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
      false;

  // Setup test rules.
  action_grammar_rules->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<event>", {"it", "is", "at", "<time>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0);
  rules.BindAnnotation("<time>", "time");
  rules.AddAnnotation("datetime");
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules->rules.get());
  action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  RulesModel_::RuleActionSpecT* actions_spec =
      action_grammar_rules->actions.back().get();
  actions_spec->action.reset(new ActionSuggestionSpecT);
  actions_spec->action->priority_score = 1.0;
  actions_spec->action->score = 1.0;
  actions_spec->action->type = "create_event";
  action_grammar_rules->rule_match.emplace_back(
      new RulesModel_::GrammarRules_::RuleMatchT);
  action_grammar_rules->rule_match.back()->action_id.push_back(0);

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "it is at 10:30",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          annotator.get());

  EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
}
1178 #endif
1179
// The unmodified test model already suggests sharing location for
// "Where are you?". Adding a regex rule that emits the same share-location
// action must not change the number of suggestions.
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // The location sharing model must have triggered. Otherwise the count
  // comparison below would not exercise deduplication at all, so this is
  // fatal.
  ASSERT_THAT(response.actions,
              testing::Contains(testing::Field(
                  &ActionSuggestion::type,
                  ActionsSuggestionsTypes::ShareLocation())));
  const int num_actions = response.actions.size();

  // Add custom rule for location sharing.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  actions_model->rules->regex_rule.back()->pattern =
      "^(?i:where are you[.?]?)$";
  actions_model->rules->regex_rule.back()->actions.emplace_back(
      new RulesModel_::RuleActionSpecT);
  actions_model->rules->regex_rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->regex_rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestionsTypes::ShareLocation();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, SizeIs(num_actions));
}
1232
// A flight annotation on "LX38" first yields track_flight on top. A regex rule
// is then added whose capturing group covers the same text, producing a
// higher-priority (2.0) "test_code" action; that action must rank first
// afterwards.
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});

  // Check that the flight tracking action is present.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");

  // Add custom rule.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->score = 1.0f;
  action->priority_score = 2.0f;
  action->type = "test_code";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->annotation_name = "code";
  code_group->annotation_type = "code";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "test_code");
}
1292 #endif
1293
// Two address annotations ("home" scored 1.0, "work" scored 2.0) must yield
// view_map actions ordered by score, highest first.
TEST_F(ActionsSuggestionsTest, RanksActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Fatal: the checks below index actions[0] and actions[1].
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
1315
// VisitActionsModel must yield true when the visitor is given a non-null model
// for the existing test model file, and false for a path that does not exist.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Visitor reporting whether a model was handed to it.
  const auto is_loaded = [](const ActionsModel* model) {
    return model != nullptr;
  };
  EXPECT_TRUE(
      VisitActionsModel<bool>(GetModelPath() + kModelFileName, is_loaded));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", is_loaded));
}
1332
// Feeds single-message conversations to the hash-gram test model and checks
// which action type, if any, each message triggers.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);

  // Suggests actions for a one-message conversation sent by user 1.
  const auto suggest = [&actions_suggestions](const std::string& text) {
    return actions_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  const ActionsSuggestionsResponse greeting = suggest("hello");
  EXPECT_THAT(greeting.actions, testing::IsEmpty());

  const ActionsSuggestionsResponse location_question = suggest("where are you");
  EXPECT_THAT(
      location_question.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));

  const ActionsSuggestionsResponse contact_question =
      suggest("do you know johns number");
  EXPECT_THAT(
      contact_question.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
}
1372
1373 // Test class to expose token embedding methods for testing.
1374 class TestingMessageEmbedder : private ActionsSuggestions {
1375 public:
1376 explicit TestingMessageEmbedder(const ActionsModel* model);
1377
1378 using ActionsSuggestions::EmbedAndFlattenTokens;
1379 using ActionsSuggestions::EmbedTokensPerMessage;
1380
1381 protected:
1382 // EmbeddingExecutor that always returns features based on
1383 // the id of the sparse features.
1384 class FakeEmbeddingExecutor : public EmbeddingExecutor {
1385 public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1386 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1387 const int dest_size) const override {
1388 TC3_CHECK_GE(dest_size, 1);
1389 EXPECT_EQ(sparse_features.size(), 1);
1390 dest[0] = sparse_features.data()[0];
1391 return true;
1392 }
1393 };
1394
1395 std::unique_ptr<UniLib> unilib_;
1396 };
1397
TestingMessageEmbedder(const ActionsModel * model)1398 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
1399 : unilib_(CreateUniLibForTesting()) {
1400 model_ = model;
1401 const ActionsTokenFeatureProcessorOptions* options =
1402 model->feature_processor_options();
1403 feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
1404 embedding_executor_.reset(new FakeEmbeddingExecutor());
1405 EXPECT_TRUE(
1406 EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1407 EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1408 EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1409 token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1410 EXPECT_EQ(token_embedding_size_, 1);
1411 }
1412
1413 class EmbeddingTest : public testing::Test {
1414 protected:
EmbeddingTest()1415 explicit EmbeddingTest() {
1416 model_.feature_processor_options.reset(
1417 new ActionsTokenFeatureProcessorOptionsT);
1418 options_ = model_.feature_processor_options.get();
1419 options_->chargram_orders = {1};
1420 options_->num_buckets = 1000;
1421 options_->embedding_size = 1;
1422 options_->start_token_id = 0;
1423 options_->end_token_id = 1;
1424 options_->padding_token_id = 2;
1425 options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1426 }
1427
CreateTestingMessageEmbedder()1428 TestingMessageEmbedder CreateTestingMessageEmbedder() {
1429 flatbuffers::FlatBufferBuilder builder;
1430 FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1431 buffer_ = builder.Release();
1432 return TestingMessageEmbedder(
1433 flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1434 }
1435
1436 flatbuffers::DetachedBuffer buffer_;
1437 ActionsModelT model_;
1438 ActionsTokenFeatureProcessorOptionsT* options_;
1439 };
1440
// Without min/max bounds, a message is embedded token by token with neither
// padding nor truncation.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // ElementsAre checks the length as well as every value, so a result of the
  // wrong size fails cleanly instead of being indexed out of bounds.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                  options_->num_buckets)));
}
1460
// A message shorter than min_num_tokens_per_message is padded at the end
// with the padding token.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->padding_token_id),
          FloatEq(options_->padding_token_id)));
}
1483
// A message longer than max_num_tokens_per_message keeps only its last
// tokens; the excess is dropped from the beginning.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                  options_->num_buckets)));
}
1502
// With several messages, each one is padded to the length of the longest
// message and the results are laid out message after message.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // 2 messages x 3 slots. The original test never checked the size before
  // reading embeddings[5]; ElementsAre now pins it.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets),
          FloatEq(options_->padding_token_id)));
}
1527
// Flattening wraps the message in start and end tokens.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(options_->start_token_id),
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id)));
}
1549
// Flattened output shorter than min_num_total_tokens is padded after the
// end token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(options_->start_token_id),
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id),
          FloatEq(options_->padding_token_id),
          FloatEq(options_->padding_token_id)));
}
1574
// When flattened output exceeds max_num_total_tokens, entries are dropped
// from the front; here that removes the start token and "a".
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id)));
}
1594
// Messages are flattened in order, each wrapped in its own start/end tokens.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(options_->start_token_id),
          FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id), FloatEq(options_->start_token_id),
          FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id)));
}
1623
// Truncation to max_num_total_tokens drops from the front of the flattened
// sequence, across message boundaries: of the 10 flattened entries, the
// first 3 (start, "a", "b") are removed.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  // Fatal: if embedding fails, `embeddings` is not worth inspecting.
  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre also enforces the length, so no element is read out of range.
  EXPECT_THAT(
      embeddings,
      ElementsAre(
          FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id), FloatEq(options_->start_token_id),
          FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets),
          FloatEq(tc3farmhash::Fingerprint64("f", 1) % options_->num_buckets),
          FloatEq(options_->end_token_id)));
}
1651
// With default options the multi-task model yields every classification
// action plus the smart replies.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // 8 binary classification actions + 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(11));
}
1663
// Confidence threshold used to switch a classification head off. It is meant
// to be unreachable; presumably model confidences lie in [0, 1]
// (TODO confirm), so a head with this threshold never fires.
constexpr float kDisableThresholdVal = 2.0f;

// Keys into ActionSuggestionOptions::model_parameters that the multi-task
// test model reads.
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";
// "Parm" (sic) is kept so existing references continue to compile.
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
1676
GetOptionsToDisableAllClassification()1677 ActionSuggestionOptions GetOptionsToDisableAllClassification() {
1678 ActionSuggestionOptions options;
1679 // Disable all classification heads.
1680 options.model_parameters.insert(
1681 {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1682 options.model_parameters.insert(
1683 {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1684 options.model_parameters.insert(
1685 {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1686 options.model_parameters.insert(
1687 {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1688 options.model_parameters.insert(
1689 {kRestaurantsThreshold,
1690 libtextclassifier3::Variant(kDisableThresholdVal)});
1691 options.model_parameters.insert(
1692 {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1693 options.model_parameters.insert(
1694 {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1695 options.model_parameters.insert(
1696 {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1697 return options;
1698 }
1699
// With every classification head disabled, only the smart replies remain.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionSuggestionOptions options =
      GetOptionsToDisableAllClassification();
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // Exactly 3 smart replies; ElementsAre also enforces the count.
  EXPECT_THAT(response.actions,
              ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
                          IsSmartReply("I'm home")));
}
1716
// Length of the synthetic user profile fed to the smart-reply
// personalization (p13n) test model.
constexpr int kUserProfileSize = 1000;
// model_parameters keys for the user profile's token indexes and weights.
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
1720
GetOptionsForSmartReplyP13nModel()1721 ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
1722 ActionSuggestionOptions options;
1723 const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
1724 const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
1725 options.model_parameters.insert(
1726 {kUserProfileTokenIndex,
1727 libtextclassifier3::Variant(user_profile_token_indexes)});
1728 options.model_parameters.insert(
1729 {kUserProfileTokenWeight,
1730 libtextclassifier3::Variant(user_profile_token_weights)});
1731 return options;
1732 }
1733
// The personalization model accepts a user profile through model_parameters
// and still returns smart replies.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskSrP13nTestModel();
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionSuggestionOptions options = GetOptionsForSmartReplyP13nModel();
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(3));
}
1746
// Enables only the location head (threshold 0.35) and sets the smart-reply
// diversification distance threshold to 0.5.
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  options.model_parameters.insert(
      {kDiversificationParm, libtextclassifier3::Variant(0.5f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // 1 location share + 3 smart replies; ElementsAre also enforces the count.
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsActionOfType("LOCATION_SHARE"), IsSmartReply("Here"),
                  IsSmartReply("Yes"), IsSmartReply("")));
}
1768
// Enables the location head (threshold 0.35), forces the reminder head on
// with a zero threshold, and sets an empirical-probability factor of 2.0
// (per the test name, presumably boosting smart replies).
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  // The reminder head always triggers because its threshold is zero.
  options.model_parameters[kReminderThreshold] =
      libtextclassifier3::Variant(0.0f);
  options.model_parameters.insert(
      {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // 3 smart replies + 1 location share + 1 reminder = 5 actions; ElementsAre
  // also enforces the count.
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
                  IsSmartReply("Yes"),
                  // Different emoji than previous test.
                  // NOTE(review): this expected literal is empty; confirm the
                  // intended emoji was not lost in an encoding round-trip.
                  IsSmartReply(""), IsActionOfType("REMINDER_INTENT")));
}
1795
// The smart-reply + emoji model ranks an emoji concept first, followed by
// text replies.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal, so the indexed checks below never read past the end.
  ASSERT_THAT(response.actions, SizeIs(5));
  // NOTE(review): an EMOJI_CONCEPT action is expected with empty text;
  // confirm the expected emoji was not lost in an encoding round-trip.
  EXPECT_EQ(response.actions[0].response_text, "");
  EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
  EXPECT_EQ(response.actions[1].response_text, "Yes");
  EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
}
1812
// A conversation the sensitive-topic model flags yields no actions. The
// response is marked sensitive and not low-confidence.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kSensitiveTFliteModelFileName);
  // Fail cleanly, rather than dereferencing null, if loading the model fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I want to kill myself",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
  EXPECT_TRUE(response.is_sensitive);
  EXPECT_FALSE(response.output_filtered_low_confidence);
}
1827
1828 } // namespace
1829 } // namespace libtextclassifier3
1830