/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/bert_tokenizer.h"

#include <algorithm>
#include <string>
#include <vector>

#include "annotator/types.h"
#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "absl/strings/string_view.h"

namespace libtextclassifier3 {

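// Builds the token -> id index from the given vocabulary; each token's id is
// its position in `vocab`.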
FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
    const std::vector<std::string>& vocab)
    : vocab_{vocab} {
  for (int i = 0; i < vocab_.size(); ++i) {
    index_map_[vocab_[i]] = i;
  }
}

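// Sets `value` to whether `key` is in the vocabulary. The hash-map lookup
// itself cannot fail, so a default-constructed LookupStatus is returned.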
LookupStatus FlatHashMapBackedWordpiece::Contains(absl::string_view key,
                                                  bool* value) const {
  *value = index_map_.contains(key);
  return LookupStatus();
}

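// Writes the vocabulary id of `key` to `result`, or returns false if the
// token is out of vocabulary.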
bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key,
                                          int* result) const {
  auto it = index_map_.find(key);
  if (it == index_map_.end()) {
    return false;
  }
  *result = it->second;
  return true;
}

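// Writes the token with id `vocab_id` to `result`, or returns false if the
// id is out of range.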
bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id,
                                            absl::string_view* result) const {
  if (vocab_id >= vocab_.size() || vocab_id < 0) {
    return false;
  }
  *result = vocab_[vocab_id];
  return true;
}

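// Wordpiece-tokenizes `token` as-is, skipping the pretokenization step.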
TokenizerResult BertTokenizer::TokenizeSingleToken(const std::string& token) {
  std::vector<std::string> tokens = {token};
  return BertTokenizer::Tokenize(tokens);
}

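// Tokenizes raw text: pretokenizes it into tokens, then wordpiece-tokenizes
// each token.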
TokenizerResult BertTokenizer::Tokenize(const std::string& input) {
  std::vector<std::string> tokens = PreTokenize(input);
  return BertTokenizer::Tokenize(tokens);
}

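// Wordpiece-tokenizes each pretokenized token, accumulating the subwords and
// their absolute byte offsets into a single result.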
TokenizerResult BertTokenizer::Tokenize(
    const std::vector<std::string>& tokens) {
  WordpieceTokenizerResult result;
  std::vector<std::string>& subwords = result.subwords;
  std::vector<int>& wp_absolute_begin_offset = result.wp_begin_offset;
  std::vector<int>& wp_absolute_end_offset = result.wp_end_offset;

  for (int token_index = 0; token_index < tokens.size(); token_index++) {
    auto& token = tokens[token_index];
    int num_word_pieces = 0;
    LookupStatus status = WordpieceTokenize(
        token, options_.max_bytes_per_token, options_.max_chars_per_subtoken,
        options_.suffix_indicator, options_.use_unknown_token,
        options_.unknown_token, options_.split_unknown_chars, &vocab_,
        &subwords, &wp_absolute_begin_offset, &wp_absolute_end_offset,
        &num_word_pieces);

    // Stop early and return the partial result if the wordpiece lookup
    // failed for this token.
    if (!status.success) {
      return std::move(result);
    }
  }

  return std::move(result);
}

// This replicates how the original bert_tokenizer from the tflite-support
// library pretokenizes text, using regex_split with the same default regexes.
// It splits the text on spaces, punctuation, and Chinese characters, and
// outputs all of the tokens except spaces.
// So far, the only difference from the original implementation that we are
// aware of is that the original regexes have 8 ranges of Chinese Unicode
// characters; we include those 8 ranges plus two extra ranges.
std::vector<std::string> BertTokenizer::PreTokenize(
    const absl::string_view input) {
  const std::vector<Token> tokens =
      TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
  std::vector<std::string> token_texts;
  token_texts.reserve(tokens.size());
  std::transform(tokens.begin(), tokens.end(), std::back_inserter(token_texts),
                 // `token` is a const reference, so std::move would silently
                 // copy anyway; copy explicitly instead.
                 [](const Token& token) { return token.value; });

  return token_texts;
}

}  // namespace libtextclassifier3