1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/blacklist.h"
18 
19 #include "utils/tflite/blacklist_base.h"
20 #include "utils/tflite/skipgram_finder.h"
21 #include "flatbuffers/flexbuffers.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace custom {
26 
27 namespace libtextclassifier3 {
28 namespace blacklist {
29 
30 // Generates prediction vectors for input strings using a skipgram blacklist.
31 // This uses the framework in `blacklist_base.h`, with the implementation detail
32 // that the input is a string tensor of messages and the terms are skipgrams.
33 class BlacklistOp : public BlacklistOpBase {
34  public:
BlacklistOp(const flexbuffers::Map & custom_options)35   explicit BlacklistOp(const flexbuffers::Map& custom_options)
36       : BlacklistOpBase(custom_options),
37         skipgram_finder_(custom_options["max_skip_size"].AsInt32()),
38         input_(nullptr) {
39     auto blacklist = custom_options["blacklist"].AsTypedVector();
40     auto blacklist_category =
41         custom_options["blacklist_category"].AsTypedVector();
42     for (int i = 0; i < blacklist.size(); i++) {
43       int category = blacklist_category[i].AsInt32();
44       flexbuffers::String s = blacklist[i].AsString();
45       skipgram_finder_.AddSkipgram(std::string(s.c_str(), s.length()),
46                                    category);
47     }
48   }
49 
InitializeInput(TfLiteContext * context,TfLiteNode * node)50   TfLiteStatus InitializeInput(TfLiteContext* context,
51                                TfLiteNode* node) override {
52     input_ = &context->tensors[node->inputs->data[kInputMessage]];
53     return kTfLiteOk;
54   }
55 
GetCategories(int i) const56   absl::flat_hash_set<int> GetCategories(int i) const override {
57     StringRef input = GetString(input_, i);
58     return skipgram_finder_.FindSkipgrams(std::string(input.str, input.len));
59   }
60 
FinalizeInput()61   void FinalizeInput() override { input_ = nullptr; }
62 
GetInputShape(TfLiteContext * context,TfLiteNode * node)63   TfLiteIntArray* GetInputShape(TfLiteContext* context,
64                                 TfLiteNode* node) override {
65     return context->tensors[node->inputs->data[kInputMessage]].dims;
66   }
67 
68  private:
69   ::libtextclassifier3::SkipgramFinder skipgram_finder_;
70   TfLiteTensor* input_;
71 
72   static constexpr int kInputMessage = 0;
73 };
74 
BlacklistOpInit(TfLiteContext * context,const char * buffer,size_t length)75 void* BlacklistOpInit(TfLiteContext* context, const char* buffer,
76                       size_t length) {
77   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
78   return new BlacklistOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
79 }
80 
81 }  // namespace blacklist
82 
Register_BLACKLIST()83 TfLiteRegistration* Register_BLACKLIST() {
84   static TfLiteRegistration r = {libtextclassifier3::blacklist::BlacklistOpInit,
85                                  libtextclassifier3::blacklist::Free,
86                                  libtextclassifier3::blacklist::Resize,
87                                  libtextclassifier3::blacklist::Eval};
88   return &r;
89 }
90 
91 }  // namespace libtextclassifier3
92 }  // namespace custom
93 }  // namespace ops
94 }  // namespace tflite
95