/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/tflite-sensitive-model.h"

#include <algorithm>
#include <cstring>
#include <utility>

#include "actions/actions_model_generated.h"
#include "actions/types.h"

namespace libtextclassifier3 {
namespace {
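// Reply label that the model produces for non-sensitive content; any other
// predicted label is treated as a sensitive prediction.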
const char kNotSensitive[] = "NOT_SENSITIVE";
}  // namespace

std::unique_ptr<TFLiteSensitiveModel> TFLiteSensitiveModel::Create(
    const TFLiteSensitiveClassifierConfig* model_config) {
  auto result_model = std::unique_ptr<TFLiteSensitiveModel>(
      new TFLiteSensitiveModel(model_config));
  if (result_model->model_executor_ == nullptr) {
    return nullptr;
  }
  return result_model;
}
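
// Example usage (a minimal sketch; `config` stands for a valid
// TFLiteSensitiveClassifierConfig pointer taken from the actions model):
//
//   std::unique_ptr<TFLiteSensitiveModel> model =
//       TFLiteSensitiveModel::Create(config);
//   if (model != nullptr) {
//     const std::pair<bool, float> result =
//         model->Eval(UTF8ToUnicodeText("message text", /*do_copy=*/false));
//   }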

std::pair<bool, float> TFLiteSensitiveModel::Eval(
    const UnicodeText& text) const {
  // Create a conversation with a single message and classify it.
  Conversation conversation;
  conversation.messages.emplace_back();
  conversation.messages.front().text = text.ToUTF8String();

  return EvalConversation(conversation, /*num_messages=*/1);
}

std::pair<bool, float> TFLiteSensitiveModel::EvalConversation(
    const Conversation& conversation, int num_messages) const {
  if (model_executor_ == nullptr) {
    return std::make_pair(false, 0.0f);
  }
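  // Create a fresh TFLite interpreter for this evaluation.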
  const auto interpreter = model_executor_->CreateInterpreter();

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    // TODO(mgubin): Report an error that the tensors can't be allocated.
    return std::make_pair(false, 0.0f);
  }
  // The sensitive model is an ordinary TFLite model using the Lingua API, so
  // the texts and user_ids are prepared in the same way; it doesn't use time
  // diffs.
  std::vector<std::string> context;
  std::vector<int> user_ids;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);

  // Gather the last `num_messages` messages from the conversation.
  const int num_total_messages =
      static_cast<int>(conversation.messages.size());
  for (int i = std::max(0, num_total_messages - num_messages);
       i < num_total_messages; i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);
  }

  // Fill in the model inputs.
  if (model_config_->model_spec()->input_context() >= 0) {
    if (model_config_->model_spec()->input_length_to_pad() > 0) {
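      // Pad the context with empty strings (or truncate it) so it matches the
      // fixed input length the model expects.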
      context.resize(model_config_->model_spec()->input_length_to_pad());
    }
    model_executor_->SetInput<std::string>(
        model_config_->model_spec()->input_context(), context,
        interpreter.get());
  }
  if (model_config_->model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_config_->model_spec()->input_context_length(), context.size(),
        interpreter.get());
  }

  // The number of suggestions is always fixed at 3.
  if (model_config_->model_spec()->input_num_suggestions() > 0) {
    model_executor_->SetInput<int>(
        model_config_->model_spec()->input_num_suggestions(), 3,
        interpreter.get());
  }

  if (interpreter->Invoke() != kTfLiteOk) {
    // TODO(mgubin): Report an error when Invoke() fails.
    return std::make_pair(false, 0.0f);
  }

  // Check whether any of the predicted replies is sensitive.
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(
          model_config_->model_spec()->output_replies(), interpreter.get());
  const TensorView<float> scores = model_executor_->OutputView<float>(
      model_config_->model_spec()->output_replies_scores(), interpreter.get());
  for (int i = 0; i < replies.size(); ++i) {
    const auto reply = replies[i];
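    // `reply.str` is not NUL-terminated, so the label has to be compared with
    // an explicit length.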
    if (reply.len != sizeof(kNotSensitive) - 1 ||
        0 != memcmp(reply.str, kNotSensitive, sizeof(kNotSensitive) - 1)) {
      const auto score = scores.data()[i];
      if (score >= model_config_->threshold()) {
        return std::make_pair(true, score);
      }
    }
  }
  return std::make_pair(false, 1.0f);
}

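// Note: TfLiteModelExecutor::FromBuffer() may return nullptr when the model
// buffer can't be loaded; Create() checks model_executor_ for this and fails.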
TFLiteSensitiveModel::TFLiteSensitiveModel(
    const TFLiteSensitiveClassifierConfig* model_config)
    : model_config_(model_config),
      model_executor_(TfLiteModelExecutor::FromBuffer(
          model_config->model_spec()->tflite_model())) {}
}  // namespace libtextclassifier3