/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "re2/re2.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"
#include "tensorflow_text/core/kernels/regex_split.h"
#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h"

namespace tflite {
namespace support {
namespace text {
namespace tokenizer {

// Default regex used to split input text into basic tokens before wordpiece
// tokenization: runs of whitespace, ASCII punctuation ranges ([!-/], [:-@],
// [[-`], [{-~]), Unicode punctuation (\p{P}), and the CJK ideograph blocks
// (so each CJK codepoint matches as a delimiter on its own).
// `inline` gives these header-defined constants a single definition across
// all translation units (C++17) instead of one internal-linkage copy each.
inline constexpr char kDefaultDelimRe[] =
    R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
// Same as kDefaultDelimRe but without the whitespace alternative: delimiters
// matching this pattern are presumably kept as tokens in the output (see
// tensorflow_text's RegexSplit) while plain whitespace is discarded.
inline constexpr char kDefaultIncludeDelimRe[] =
    R"(([!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
// Default limits and special-token settings forwarded to the wordpiece
// tokenizer (see tensorflow_text's WordpieceTokenize for exact semantics).
inline constexpr int kDefaultMaxBytesPerToken = 100;
inline constexpr int kDefaultMaxCharsPerSubToken = 100;
// Prefix marking a subword continuation, e.g. "play" + "##ing".
inline constexpr char kDefaultSuffixIndicator[] = "##";
inline constexpr bool kDefaultUseUnknownToken = true;
// Token emitted for out-of-vocabulary input when unknown tokens are enabled.
inline constexpr char kDefaultUnknownToken[] = "[UNK]";
inline constexpr bool kDefaultSplitUnknownChars = false;

// Result of wordpiece tokenization including subwords and offsets.
// Extends TokenizerResult with the position of each subword in the original
// input and with per-word subword counts.
//
// Example:
// input:                tokenize     me  please
// subwords:             token ##ize  me  plea ##se
// wp_begin_offset:     [0,      5,   9,  12,    16]
// wp_end_offset:       [     5,    8,  11,   16,  18]
// row_lengths:         [2,          1,  1]
struct WordpieceTokenizerResult : TokenizerResult {
  // Start offset of each subword in the input, inclusive (per the example,
  // "token" spans [0, 5)).
  std::vector<int> wp_begin_offset;
  // End offset of each subword in the input, exclusive.
  std::vector<int> wp_end_offset;
  // Number of subwords produced for each input word; in the example
  // "tokenize" yields 2 subwords, "me" and "please"... 1 each per row above
  // (sums to the number of subwords).
  std::vector<int> row_lengths;
};
// Options to create a BertTokenizer. Defaults come from the kDefault*
// constants above; the first five are forwarded to tensorflow_text's
// wordpiece tokenizer, the last two to the regex-based pre-splitting.
struct BertTokenizerOptions {
  // Maximum token length in bytes accepted by the wordpiece step
  // (see tensorflow_text's WordpieceTokenize for exact semantics).
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  // Maximum length of a single emitted subtoken.
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  // Prefix marking subword continuations, e.g. "##ing".
  std::string suffix_indicator = kDefaultSuffixIndicator;
  // Whether out-of-vocabulary tokens map to `unknown_token`.
  bool use_unknown_token = kDefaultUseUnknownToken;
  // Token emitted for out-of-vocabulary input, e.g. "[UNK]".
  std::string unknown_token = kDefaultUnknownToken;
  // Whether unknown tokens should be split into single characters.
  bool split_unknown_chars = kDefaultSplitUnknownChars;
  // Regex used to split input text into basic tokens.
  std::string delim_str = kDefaultDelimRe;
  // Regex matching delimiters that are themselves kept as tokens
  // (presumably per tensorflow_text's RegexSplit — it lacks \s+).
  std::string include_delim_str = kDefaultIncludeDelimRe;
};
70 // A flat-hash-map based implementation of WordpieceVocab, used in
71 // BertTokenizer to invoke tensorflow::text::WordpieceTokenize within.
72 class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab {
73  public:
74   explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);
75 
76   tensorflow::text::LookupStatus Contains(absl::string_view key,
77                                           bool* value) const override;
78   bool LookupId(absl::string_view key, int* result) const;
79   bool LookupWord(int vocab_id, absl::string_view* result) const;
VocabularySize()80   int VocabularySize() const { return vocab_.size(); }
81 
82  private:
83   // All words indexed position in vocabulary file.
84   std::vector<std::string> vocab_;
85   absl::flat_hash_map<absl::string_view, int> index_map_;
86 };
87 
88 // Wordpiece tokenizer for bert models. Initialized with a vocab file or vector.
89 class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
90  public:
91   // Initialize the tokenizer from vocab vector and tokenizer configs.
92   explicit BertTokenizer(const std::vector<std::string>& vocab,
93                          const BertTokenizerOptions& options = {})
94       : vocab_{FlatHashMapBackedWordpiece(vocab)},
95         options_{options},
96         delim_re_{options.delim_str},
97         include_delim_re_{options.include_delim_str} {}
98 
99   // Initialize the tokenizer from file path to vocab and tokenizer configs.
100   explicit BertTokenizer(const std::string& path_to_vocab,
101                          const BertTokenizerOptions& options = {})
BertTokenizer(utils::LoadVocabFromFile (path_to_vocab),options)102       : BertTokenizer(utils::LoadVocabFromFile(path_to_vocab), options) {}
103 
104   // Initialize the tokenizer from buffer and size of vocab and tokenizer
105   // configs.
106   BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
107                 const BertTokenizerOptions& options = {})
BertTokenizer(utils::LoadVocabFromBuffer (vocab_buffer_data,vocab_buffer_size),options)108       : BertTokenizer(
109             utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
110             options) {}
111 
112   // Perform tokenization, return tokenized results containing the subwords.
113   TokenizerResult Tokenize(const std::string& input) override;
114 
115   // Perform tokenization, return wordpiece-specific tokenized result including
116   // subwords and offsets
117   WordpieceTokenizerResult TokenizeWordpiece(const std::string& input);
118 
119   // Check if a certain key is included in the vocab.
Contains(const absl::string_view key,bool * value)120   tensorflow::text::LookupStatus Contains(const absl::string_view key,
121                                           bool* value) const {
122     return vocab_.Contains(key, value);
123   }
124 
125   // Find the id of a wordpiece.
LookupId(absl::string_view key,int * result)126   bool LookupId(absl::string_view key, int* result) const override {
127     return vocab_.LookupId(key, result);
128   }
129 
130   // Find the wordpiece from an id.
LookupWord(int vocab_id,absl::string_view * result)131   bool LookupWord(int vocab_id, absl::string_view* result) const override {
132     return vocab_.LookupWord(vocab_id, result);
133   }
134 
VocabularySize()135   int VocabularySize() const { return vocab_.VocabularySize(); }
136 
137  private:
138   tflite::support::text::tokenizer::FlatHashMapBackedWordpiece vocab_;
139   BertTokenizerOptions options_;
140   RE2 delim_re_;
141   RE2 include_delim_re_;
142 };
143 
}  // namespace tokenizer
}  // namespace text
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_