/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
#define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "utils/wordpiece_tokenizer.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace libtextclassifier3 {

using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::support::utils::LoadVocabFromFile;

constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;

// Result of wordpiece tokenization, including subwords and their offsets.
// Example:
//   input:           tokenize me please
//   subwords:        token ##ize me plea ##se
//   wp_begin_offset: [0, 5,  9, 12, 16]
//   wp_end_offset:   [5, 8, 11, 16, 18]
//   row_lengths:     [2, 1, 1]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  std::vector<int> wp_begin_offset;
  std::vector<int> wp_end_offset;
  std::vector<int> row_lengths;
};

// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  std::string suffix_indicator = kDefaultSuffixIndicator;
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};

// A flat-hash-map-based implementation of WordpieceVocab, used by
// BertTokenizer to invoke tensorflow::text::WordpieceTokenize.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  LookupStatus Contains(absl::string_view key, bool* value) const override;
  bool LookupId(absl::string_view key, int* result) const;
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words, indexed by their position in the vocabulary file.
  std::vector<std::string> vocab_;
  absl::flat_hash_map<absl::string_view, int> index_map_;
};

// Wordpiece tokenizer for BERT models. Initialized with a vocab file, buffer,
// or vector.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from a vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initialize the tokenizer from a file path to the vocab and tokenizer
  // configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from a vocab buffer and its size, and tokenizer
  // configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Perform tokenization: first tokenize the input, then find the subwords.
  // Return tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform tokenization on a single token; return tokenized results
  // containing the subwords.
  TokenizerResult TokenizeSingleToken(const std::string& token);

  // Perform tokenization on a vector of tokens; return tokenized results
  // containing the subwords.
  TokenizerResult Tokenize(const std::vector<std::string>& tokens);

  // Check if a certain key is included in the vocab.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }

  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
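
// Usage sketch: building a tokenizer from an in-memory vocab and tokenizing a
// sentence. The vocab contents below are hypothetical, and this assumes the
// inherited TokenizerResult exposes the wordpieces as a `subwords` vector of
// strings.
//
//   std::vector<std::string> vocab = {"[UNK]", "token", "##ize", "me",
//                                     "plea", "##se"};
//   libtextclassifier3::BertTokenizer tokenizer(vocab);
//   auto result = tokenizer.Tokenize("tokenize me please");
//   // Per the WordpieceTokenizerResult example above, the subwords should be:
//   //   token ##ize me plea ##se
//   for (const std::string& subword : result.subwords) {
//     int id;
//     if (tokenizer.LookupId(subword, &id)) {
//       // Map each wordpiece back to its vocab id, e.g. to build model input.
//     }
//   }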