/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "re2/re2.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"
#include "tensorflow_text/core/kernels/regex_split.h"
#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h"

namespace tflite {
namespace support {
namespace text {
namespace tokenizer {

// Default regex used for the basic split before wordpiece tokenization:
// matches runs of whitespace, ASCII punctuation, Unicode punctuation, and
// CJK ideograph ranges (which BERT treats as individual tokens).
constexpr char kDefaultDelimRe[] =
    R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
// Same as above minus whitespace: delimiters matching this regex are kept as
// tokens rather than discarded.
constexpr char kDefaultIncludeDelimRe[] =
    R"(([!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;

// Result of wordpiece tokenization including subwords and offsets.
// Example:
//   input:           tokenize me please
//   subwords:        token ##ize me plea ##se
//   wp_begin_offset: [0, 5, 9, 12, 16]
//   wp_end_offset:   [5, 8, 11, 16, 18]
//   row_lengths:     [2, 1, 1]
struct WordpieceTokenizerResult : TokenizerResult {
  std::vector<int> wp_begin_offset;
  std::vector<int> wp_end_offset;
  std::vector<int> row_lengths;
};

// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  std::string suffix_indicator = kDefaultSuffixIndicator;
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  bool split_unknown_chars = kDefaultSplitUnknownChars;
  std::string delim_str = kDefaultDelimRe;
  std::string include_delim_str = kDefaultIncludeDelimRe;
};
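// Example usage of BertTokenizer (an illustrative sketch: the vocabulary and
// input string below are hypothetical, and `subwords` comes from the base
// TokenizerResult declared in tokenizer.h):
//
//   std::vector<std::string> vocab = {"[UNK]", "token", "##ize", "me",
//                                     "plea", "##se"};
//   BertTokenizer tokenizer(vocab);
//   WordpieceTokenizerResult result =
//       tokenizer.TokenizeWordpiece("tokenize me please");
//   // result.subwords: {"token", "##ize", "me", "plea", "##se"}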
// A flat-hash-map based implementation of WordpieceVocab, used by
// BertTokenizer to invoke tensorflow::text::WordpieceTokenize.
class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  tensorflow::text::LookupStatus Contains(absl::string_view key,
                                          bool* value) const override;
  bool LookupId(absl::string_view key, int* result) const;
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words, indexed by their position in the vocabulary file.
  std::vector<std::string> vocab_;
  absl::flat_hash_map<absl::string_view, int> index_map_;
};

// Wordpiece tokenizer for BERT models. Initialized with a vocab file, buffer,
// or vector.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from a vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)},
        options_{options},
        delim_re_{options.delim_str},
        include_delim_re_{options.include_delim_str} {}

  // Initialize the tokenizer from the path to a vocab file and tokenizer
  // configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(utils::LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from a vocab buffer and its size, plus tokenizer
  // configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(
            utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
            options) {}

  // Perform tokenization; return a tokenized result containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform tokenization; return a wordpiece-specific tokenized result
  // including subwords and offsets.
  WordpieceTokenizerResult TokenizeWordpiece(const std::string& input);

  // Check if a certain key is included in the vocab.
  tensorflow::text::LookupStatus Contains(const absl::string_view key,
                                          bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }

 private:
  tflite::support::text::tokenizer::FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
  RE2 delim_re_;
  RE2 include_delim_re_;
};

}  // namespace tokenizer
}  // namespace text
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_