1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
19 
20 #include <memory>
21 
22 #include "actions/actions_model_generated.h"
23 #include "actions/sensitive-classifier-base.h"
24 #include "utils/tflite-model-executor.h"
25 
26 namespace libtextclassifier3 {
27 class TFLiteSensitiveModel : public SensitiveTopicModelBase {
28  public:
29   // The object keeps but doesn't own model_config.
30   static std::unique_ptr<TFLiteSensitiveModel> Create(
31       const TFLiteSensitiveClassifierConfig* model_config);
32 
33   std::pair<bool, float> Eval(const UnicodeText& text) const override;
34   std::pair<bool, float> EvalConversation(const Conversation& conversation,
35                                           int num_messages) const override;
36 
37  private:
38   explicit TFLiteSensitiveModel(
39       const TFLiteSensitiveClassifierConfig* model_config);
40   const TFLiteSensitiveClassifierConfig* model_config_ = nullptr;  // not owned.
41   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
42 };
43 }  // namespace libtextclassifier3
44 
45 #endif  // LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
46