1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
19 
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/lite/string_util.h"
27 
28 namespace libtextclassifier3 {
29 
// SkipgramFinder finds skipgrams in strings.
//
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
// associated with some category.  Then, call FindSkipgrams() on a string,
// which will return the set of categories of the skipgrams in the string.
//
// Both the skipgrams and the input strings will be tokenized by splitting
// on spaces.  Additionally, the tokens will be lowercased and have any
// trailing punctuation removed.
39 class SkipgramFinder {
40  public:
SkipgramFinder(int max_skip_size)41   explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}
42 
43   // Adds a skipgram that SkipgramFinder should look for in input strings.
44   // Tokens may use the regex '.*' as a suffix.
45   void AddSkipgram(const std::string& skipgram, int category);
46 
47   // Find all of the skipgrams in `input`, and return their categories.
48   absl::flat_hash_set<int> FindSkipgrams(const std::string& input) const;
49 
50   // Find all of the skipgrams in `tokens`, and return their categories.
51   absl::flat_hash_set<int> FindSkipgrams(
52       const std::vector<absl::string_view>& tokens) const;
53   absl::flat_hash_set<int> FindSkipgrams(
54       const std::vector<::tflite::StringRef>& tokens) const;
55 
56  private:
57   struct TrieNode {
58     absl::flat_hash_set<int> categories;
59     // Maps tokens to the next node in the trie.
60     absl::flat_hash_map<std::string, TrieNode> token_to_node;
61     // Maps token prefixes (<prefix>.*) to the next node in the trie.
62     absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
63   };
64 
65   TrieNode skipgram_trie_;
66   int max_skip_size_;
67 };
68 
69 }  // namespace libtextclassifier3
70 #endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
71