1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "flatbuffers/flexbuffers.h"
22 #include "tensorflow/lite/context.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 namespace libtextclassifier3 {
28 namespace blacklist {
29 
/*
 * A framework for writing ops that generate prediction vectors using a
 * blacklist.
 *
 * Input is defined by the specific implementation.
 *
 * Attributes:
 *   blacklist:           string[n]
 *     Terms in the blacklist.
 *   blacklist_category:  int[n]
 *     Category for each term in the blacklist.  Each category must be in
 *     [0, categories).
 *   categories:          int[]
 *     Total number of categories (a scalar).
 *   negative_categories: int[]
 *     Total number of negative categories (a scalar).
 *
 * Output:
 *   tensor[0]: Category indicators for each message, float[..., categories]
 *
 */
51 
52 class BlacklistOpBase {
53  public:
BlacklistOpBase(const flexbuffers::Map & custom_options)54   explicit BlacklistOpBase(const flexbuffers::Map& custom_options)
55       : categories_(custom_options["categories"].AsInt32()),
56         negative_categories_(custom_options["negative_categories"].AsInt32()) {}
57 
~BlacklistOpBase()58   virtual ~BlacklistOpBase() {}
59 
categories()60   int categories() const { return categories_; }
negative_categories()61   int negative_categories() const { return negative_categories_; }
62 
63   virtual TfLiteStatus InitializeInput(TfLiteContext* context,
64                                        TfLiteNode* node) = 0;
65   virtual absl::flat_hash_set<int> GetCategories(int i) const = 0;
66   virtual void FinalizeInput() = 0;
67 
68   // Returns the input shape.  TfLiteIntArray is owned by the object.
69   virtual TfLiteIntArray* GetInputShape(TfLiteContext* context,
70                                         TfLiteNode* node) = 0;
71 
72  private:
73   int categories_;
74   int negative_categories_;
75 };
76 
// Individual ops should define an Init() function that returns an instance of
// a BlacklistOpBase subclass (presumably a heap-allocated one that Free()
// later releases).
79 
80 void Free(TfLiteContext* context, void* buffer);
81 
82 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node);
83 
84 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node);
85 }  // namespace blacklist
86 }  // namespace libtextclassifier3
87 }  // namespace custom
88 }  // namespace ops
89 }  // namespace tflite
90 
91 #endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
92