1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/wordpiece_tokenizer.h"
18 
19 #include "utils/utf8/unicodetext.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 
24 namespace libtextclassifier3 {
25 
26 namespace {
27 
Lookup(int byte_start,int byte_end,const absl::string_view token,const std::string & suffix_indicator,const WordpieceVocab * vocab_map,bool * in_vocab)28 LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
29                     const std::string& suffix_indicator,
30                     const WordpieceVocab* vocab_map, bool* in_vocab) {
31   int byte_len = byte_end - byte_start;
32   absl::string_view substr(token.data() + byte_start, byte_len);
33   std::string lookup_value;
34   if (byte_start > 0) {
35     lookup_value = absl::StrCat(suffix_indicator, substr);
36   } else {
37     // absl::CopyToString
38     lookup_value.assign(substr.begin(), substr.end());
39   }
40   return vocab_map->Contains(lookup_value, in_vocab);
41 }
42 
43 // Sets byte_end to the longest byte sequence which:
44 // 1) is a proper UTF8 sequence
45 // 2) is in the vocab OR if split_unknown_characters is true, is a single
46 //    UTF8 character.
47 // If no match is found, found_match is set to false.
LongestMatchStartingAt(int byte_start,const absl::string_view token,const std::string & suffix_indicator,const int max_chars_per_subtoken,bool split_unknown_characters,const WordpieceVocab * vocab_map,int * byte_end,bool * found_match,bool * match_is_unknown_character)48 LookupStatus LongestMatchStartingAt(
49     int byte_start, const absl::string_view token,
50     const std::string& suffix_indicator, const int max_chars_per_subtoken,
51     bool split_unknown_characters, const WordpieceVocab* vocab_map,
52     int* byte_end, bool* found_match, bool* match_is_unknown_character) {
53   *match_is_unknown_character = false;
54   *found_match = false;
55   const UnicodeText unicode_token =
56       UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
57   std::vector<int32_t> byte_ends;
58   int32_t codepoint_offset = byte_start;
59   for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
60     codepoint_offset += it.utf8_length();
61     byte_ends.push_back(codepoint_offset);
62     if (max_chars_per_subtoken > 0 &&
63         byte_ends.size() == max_chars_per_subtoken) {
64       // If the max bytes of a subtoken is known, do not search beyond that
65       // length.
66       break;
67     }
68   }
69   int n = byte_ends.size();
70   for (int i = n - 1; i >= 0; i--) {
71     bool in_vocab;
72     auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
73                          vocab_map, &in_vocab);
74     if (!status.success) return status;
75     if (in_vocab) {
76       *byte_end = byte_ends[i];
77       *found_match = true;
78       return LookupStatus::OK();
79     }
80     if (i == 0 && split_unknown_characters) {
81       *byte_end = byte_ends[0];
82       *found_match = true;
83       *match_is_unknown_character = true;
84       return LookupStatus::OK();
85     }
86   }
87   return LookupStatus::OK();
88 }
89 
90 // Sets the outputs 'begin_offset', 'end_offset' and 'num_word_pieces' when no
91 // token is found.
NoTokenFound(const absl::string_view token,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)92 LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
93                           const std::string& unknown_token,
94                           std::vector<std::string>* subwords,
95                           std::vector<int>* begin_offset,
96                           std::vector<int>* end_offset, int* num_word_pieces) {
97   begin_offset->push_back(0);
98   if (use_unknown_token) {
99     subwords->push_back(unknown_token);
100     end_offset->push_back(token.length());
101   } else {
102     subwords->emplace_back(token.data(), token.length());
103     end_offset->push_back(token.length());
104   }
105   ++(*num_word_pieces);
106 
107   return LookupStatus::OK();
108 }
109 
110 // When a subword is found, this helper function will add the outputs to
111 // 'subwords', 'begin_offset' and 'end_offset'.
AddWord(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)112 void AddWord(const absl::string_view token, int byte_start, int byte_end,
113              const std::string& suffix_indicator,
114              std::vector<std::string>* subwords, std::vector<int>* begin_offset,
115              std::vector<int>* end_offset) {
116   begin_offset->push_back(byte_start);
117   int len = byte_end - byte_start;
118 
119   if (byte_start > 0) {
120     // Prepend suffix_indicator if the token is within a word.
121     subwords->push_back(::absl::StrCat(
122         suffix_indicator, absl::string_view(token.data() + byte_start, len)));
123   } else {
124     subwords->emplace_back(token.data(), len);
125   }
126   end_offset->push_back(byte_end);
127 }
128 
129 // Adds a single unknown character subword, found when split_unknown_characters
130 // is true.
AddUnknownCharacter(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)131 void AddUnknownCharacter(const absl::string_view token, int byte_start,
132                          int byte_end, const std::string& suffix_indicator,
133                          bool use_unknown_token,
134                          const std::string& unknown_token,
135                          std::vector<std::string>* subwords,
136                          std::vector<int>* begin_offset,
137                          std::vector<int>* end_offset) {
138   begin_offset->push_back(byte_start);
139   end_offset->push_back(byte_end);
140   int len = byte_end - byte_start;
141   if (use_unknown_token) {
142     if (byte_start > 0) {
143       // Prepend suffix_indicator if the character is within a word.
144       subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token));
145     } else {
146       subwords->push_back(unknown_token);
147     }
148   } else {
149     if (byte_start > 0) {
150       // Prepend suffix_indicator if the character is within a word.
151       subwords->push_back(::absl::StrCat(
152           suffix_indicator, absl::string_view(token.data() + byte_start, len)));
153     } else {
154       subwords->emplace_back(token.data(), len);
155     }
156   }
157 }
158 
TokenizeL2RGreedy(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)159 LookupStatus TokenizeL2RGreedy(
160     const absl::string_view token, const int max_bytes_per_token,
161     const int max_chars_per_subtoken, const std::string& suffix_indicator,
162     bool use_unknown_token, const std::string& unknown_token,
163     bool split_unknown_characters, const WordpieceVocab* vocab_map,
164     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
165     std::vector<int>* end_offset, int* num_word_pieces) {
166   std::vector<std::string> candidate_subwords;
167   std::vector<int> candidate_begin_offsets;
168   std::vector<int> candidate_end_offsets;
169   const int token_len = token.length();
170   for (int byte_start = 0; byte_start < token_len;) {
171     int byte_end;
172     bool found_subword;
173     bool match_is_unknown_character;
174     auto status = LongestMatchStartingAt(
175         byte_start, token, suffix_indicator, max_chars_per_subtoken,
176         split_unknown_characters, vocab_map, &byte_end, &found_subword,
177         &match_is_unknown_character);
178     if (!status.success) return status;
179     if (found_subword) {
180       if (match_is_unknown_character) {
181         AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator,
182                             use_unknown_token, unknown_token,
183                             &candidate_subwords, &candidate_begin_offsets,
184                             &candidate_end_offsets);
185       } else {
186         AddWord(token, byte_start, byte_end, suffix_indicator,
187                 &candidate_subwords, &candidate_begin_offsets,
188                 &candidate_end_offsets);
189       }
190       byte_start = byte_end;
191     } else {
192       return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
193                           begin_offset, end_offset, num_word_pieces);
194     }
195   }
196 
197   subwords->insert(subwords->end(), candidate_subwords.begin(),
198                    candidate_subwords.end());
199   begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(),
200                        candidate_begin_offsets.end());
201   end_offset->insert(end_offset->end(), candidate_end_offsets.begin(),
202                      candidate_end_offsets.end());
203   *num_word_pieces += candidate_subwords.size();
204   return LookupStatus::OK();
205 }
206 
207 }  // namespace
208 
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)209 LookupStatus WordpieceTokenize(
210     const absl::string_view token, const int max_bytes_per_token,
211     const int max_chars_per_subtoken, const std::string& suffix_indicator,
212     bool use_unknown_token, const std::string& unknown_token,
213     bool split_unknown_characters, const WordpieceVocab* vocab_map,
214     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
215     std::vector<int>* end_offset, int* num_word_pieces) {
216   int token_len = token.size();
217   if (token_len > max_bytes_per_token) {
218     begin_offset->push_back(0);
219     *num_word_pieces = 1;
220     if (use_unknown_token) {
221       end_offset->push_back(unknown_token.size());
222       subwords->emplace_back(unknown_token);
223     } else {
224       subwords->emplace_back(token);
225       end_offset->push_back(token.size());
226     }
227     return LookupStatus::OK();
228   }
229   return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
230                            suffix_indicator, use_unknown_token, unknown_token,
231                            split_unknown_characters, vocab_map, subwords,
232                            begin_offset, end_offset, num_word_pieces);
233 }
234 
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)235 LookupStatus WordpieceTokenize(
236     const absl::string_view token, const int max_bytes_per_token,
237     const std::string& suffix_indicator, bool use_unknown_token,
238     const std::string& unknown_token, const WordpieceVocab* vocab_map,
239     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
240     std::vector<int>* end_offset, int* num_word_pieces) {
241   return WordpieceTokenize(token, max_bytes_per_token,
242                            /* max_chars_per_subtoken= */ 0, suffix_indicator,
243                            use_unknown_token, unknown_token,
244                            /* split_unknown_characters= */ false, vocab_map,
245                            subwords, begin_offset, end_offset, num_word_pieces);
246 }
247 }  // namespace libtextclassifier3
248