/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
#define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "utils/wordpiece_tokenizer.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace libtextclassifier3 {

using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::support::utils::LoadVocabFromFile;

constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;
// Result of wordpiece tokenization including subwords and offsets.
// Example:
// input:                tokenize     me  please
// subwords:             token ##ize  me  plea ##se
// wp_begin_offset:     [0,      5,   9,  12,    16]
// wp_end_offset:       [     5,    8,  11,   16,  18]
// row_lengths:         [2,          1,  2]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  std::vector<int> wp_begin_offset;
  std::vector<int> wp_end_offset;
  std::vector<int> row_lengths;
};
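
// Illustrative sketch (not part of the original API): shows how row_lengths
// maps the flat subword list back to the original input tokens, assuming the
// `subwords` field inherited from TokenizerResult holds the wordpieces.
inline std::vector<std::vector<std::string>> GroupSubwordsByToken(
    const WordpieceTokenizerResult& result) {
  std::vector<std::vector<std::string>> grouped;
  int subword_index = 0;
  for (const int row_length : result.row_lengths) {
    // The next `row_length` subwords all come from the same input token.
    grouped.emplace_back(result.subwords.begin() + subword_index,
                         result.subwords.begin() + subword_index + row_length);
    subword_index += row_length;
  }
  return grouped;
}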

// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  std::string suffix_indicator = kDefaultSuffixIndicator;
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};
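
// Illustrative sketch (not part of the original API): builds an options struct
// that overrides a subset of the defaults; the remaining fields keep the
// default values defined above. The chosen values are arbitrary examples.
inline BertTokenizerOptions MakeExampleBertTokenizerOptions() {
  BertTokenizerOptions options;
  options.max_bytes_per_token = 200;   // Accept longer raw tokens.
  options.split_unknown_chars = true;  // Split characters missing from the vocab.
  return options;
}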

// A flat-hash-map based implementation of WordpieceVocab, used by
// BertTokenizer when invoking tensorflow::text::WordpieceTokenize.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  LookupStatus Contains(absl::string_view key, bool* value) const override;
  bool LookupId(absl::string_view key, int* result) const;
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words, indexed by their position in the vocabulary file.
  std::vector<std::string> vocab_;
  absl::flat_hash_map<absl::string_view, int> index_map_;
};
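
// Illustrative sketch (not part of the original API): round-trips a wordpiece
// through LookupId/LookupWord on a tiny in-memory vocabulary; the three
// entries below are arbitrary examples.
inline bool ExampleWordpieceRoundTrip() {
  const FlatHashMapBackedWordpiece vocab({"[UNK]", "token", "##ize"});
  int id = -1;
  if (!vocab.LookupId("##ize", &id)) return false;  // id is the word's index.
  absl::string_view word;
  return vocab.LookupWord(id, &word) && word == "##ize";
}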

// Wordpiece tokenizer for BERT models. Initialized with a vocab file or vector.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initialize the tokenizer from file path to vocab and tokenizer configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from buffer and size of vocab and tokenizer
  // configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Perform tokenization: first tokenize the input, then find the subwords.
  // Returns the tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform wordpiece tokenization on a single token. Returns the tokenized
  // results containing the subwords.
  TokenizerResult TokenizeSingleToken(const std::string& token);

  // Perform tokenization on a list of tokens. Returns the tokenized results
  // containing the subwords.
  TokenizerResult Tokenize(const std::vector<std::string>& tokens);

  // Check if a certain key is included in the vocab.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }
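
  // Split the input text into the tokens that wordpiece tokenization is then
  // applied to (the first step of Tokenize(const std::string&)).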
  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
};
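
// Illustrative usage sketch (not part of the original API), assuming the
// `subwords` field of TokenizerResult holds the produced wordpieces. With a
// full BERT vocabulary, tokenizing "tokenize me please" would yield subwords
// like those in the WordpieceTokenizerResult example above.
inline std::vector<std::string> ExampleBertTokenize(
    const std::vector<std::string>& vocab, const std::string& text) {
  // Build a tokenizer with default options from an in-memory vocabulary.
  BertTokenizer tokenizer(vocab);
  const TokenizerResult result = tokenizer.Tokenize(text);
  return result.subwords;
}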

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_