1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
17 
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
20 
21 #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h"
22 #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 namespace sentencepiece {
28 namespace {
29 
// UTF-8 encoding of U+2581 (LOWER ONE EIGHTH BLOCK). NormalizeString()
// substitutes this marker for whitespace bytes when whitespace escaping is
// enabled in the encoder config.
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
31 
// Rewrites `input` by repeatedly offering the not-yet-processed suffix to `pc`.
//
// `pc(data, len)` must return a pair-like value (consumed, replacement):
//   * consumed == 0: no rewrite applies here; the current byte is copied to
//     the output unchanged, together with its offset.
//   * consumed  > 0: the next `consumed` input bytes are replaced by
//     `replacement` (which may be empty, i.e. the bytes are dropped). Every
//     byte of the replacement is given the offset of the first consumed byte.
//
// `offsets` is indexed in lockstep with `input`, so it must hold at least one
// entry per input byte. Returns the rewritten string and its per-byte offsets.
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
    const std::string& input, const std::vector<int>& offsets,
    const processing_callback& pc) {
  std::string result_string;
  result_string.reserve(input.size());
  std::vector<int> result_offsets;
  result_offsets.reserve(offsets.size());
  // One cursor is sufficient: the positions in `input` and `offsets` always
  // advance by the same amount.
  for (size_t pos = 0; pos < input.size();) {
    const auto [consumed, new_string] =
        pc(input.data() + pos, input.size() - pos);
    if (consumed == 0) {
      // Nothing to rewrite at this byte; pass it through and move forward.
      result_string.push_back(input[pos]);
      result_offsets.push_back(offsets[pos]);
      ++pos;
      continue;
    }
    result_string.append(new_string.data(), new_string.length());
    // All replacement bytes point back at the start of the consumed span.
    result_offsets.insert(result_offsets.end(),
                          static_cast<size_t>(new_string.length()),
                          offsets[pos]);
    pos += static_cast<size_t>(consumed);
  }
  // Move the buffers out instead of copying them into the tuple.
  return std::make_tuple(std::move(result_string), std::move(result_offsets));
}
59 
// Returns true if `c` is one of the ASCII whitespace bytes this normalizer
// recognizes: space, horizontal tab, carriage return or line feed.
inline bool is_whitespace(char c) {
  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
63 
remove_extra_whitespaces(const char * data,int len)64 std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
65                                                              int len) {
66   if (len == 0 || !is_whitespace(*data)) {
67     return std::make_tuple(0, utils::string_view(nullptr, 0));
68   }
69   int num_consumed = 1;
70   for (; num_consumed < len && is_whitespace(data[num_consumed]);
71        ++num_consumed) {
72   }
73   return num_consumed > 1
74              ? std::make_tuple(num_consumed, utils::string_view(" ", 1))
75              : std::make_tuple(0, utils::string_view(nullptr, 0));
76 }
77 
find_replacement(const char * data,int len,const DoubleArrayTrie & dat,const flatbuffers::Vector<int8_t> & replacements)78 std::tuple<int, utils::string_view> find_replacement(
79     const char* data, int len, const DoubleArrayTrie& dat,
80     const flatbuffers::Vector<int8_t>& replacements) {
81   const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
82   if (!max_match.empty()) {
83     // Because flatbuffer byte is signed char which is not the same as char,
84     // there is the reinterpret_cast here.
85     const char* replaced_string_ptr =
86         reinterpret_cast<const char*>(replacements.data() + max_match.id);
87     return std::make_tuple(max_match.match_length,
88                            utils::string_view(replaced_string_ptr));
89   }
90   return std::make_tuple(0, utils::string_view(nullptr, 0));
91 }
92 }  // namespace
93 
NormalizeString(const std::string & in_string,const EncoderConfig & config)94 std::tuple<std::string, std::vector<int>> NormalizeString(
95     const std::string& in_string, const EncoderConfig& config) {
96   std::vector<int> output_offsets;
97   std::string result = in_string;
98   output_offsets.reserve(in_string.length());
99   for (int i = 0; i < in_string.length(); ++i) {
100     output_offsets.push_back(i);
101   }
102   if (in_string.empty()) {
103     return std::make_tuple(result, output_offsets);
104   }
105   if (config.add_dummy_prefix()) {
106     result.insert(result.begin(), ' ');
107     output_offsets.insert(output_offsets.begin(), 0);
108   }
109   // Greedely replace normalized_prefixes with normalized_replacements
110   if (config.normalized_prefixes() != nullptr &&
111       config.normalized_replacements() != nullptr) {
112     const DoubleArrayTrie normalized_prefixes_matcher(
113         config.normalized_prefixes()->nodes());
114     const auto norm_replace = [&config, &normalized_prefixes_matcher](
115                                   const char* data, int len) {
116       return find_replacement(data, len, normalized_prefixes_matcher,
117                               *config.normalized_replacements());
118     };
119     std::tie(result, output_offsets) =
120         process_string(result, output_offsets, norm_replace);
121   }
122   if (config.remove_extra_whitespaces()) {
123     std::tie(result, output_offsets) =
124         process_string(result, output_offsets, remove_extra_whitespaces);
125     if (!result.empty() && is_whitespace(result.back())) {
126       result.pop_back();
127       output_offsets.pop_back();
128     }
129   }
130   if (config.escape_whitespaces()) {
131     const auto replace_whitespaces = [](const char* data, int len) {
132       if (len > 0 && is_whitespace(*data)) {
133         return std::make_tuple(1, utils::string_view(kSpaceSymbol));
134       }
135       return std::make_tuple(0, utils::string_view(nullptr, 0));
136     };
137     std::tie(result, output_offsets) =
138         process_string(result, output_offsets, replace_whitespaces);
139   }
140 
141   return std::make_tuple(result, output_offsets);
142 }
143 
EncodeNormalizedString(const std::string & str,const std::vector<int> & offsets,const EncoderConfig & config,bool add_bos,bool add_eos,bool reverse)144 EncoderResult EncodeNormalizedString(const std::string& str,
145                                      const std::vector<int>& offsets,
146                                      const EncoderConfig& config, bool add_bos,
147                                      bool add_eos, bool reverse) {
148   const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
149   const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
150   const int unknown_code = config.unknown_code();
151   const float unknown_penalty = config.unknown_penalty();
152   struct LatticeElement {
153     float score = 0;
154     int code = -1;
155     int prev_position = -1;
156     LatticeElement(float score_, int code_, int prev_position_)
157         : score(score_), code(code_), prev_position(prev_position_) {}
158     LatticeElement() {}
159   };
160   const int length = str.length();
161   std::vector<LatticeElement> lattice(length + 1);
162   for (int i = 0; i < length; ++i) {
163     if (i > 0 && lattice[i].prev_position < 0) {
164       // This state is unreachable.
165       continue;
166     }
167     if (unknown_code >= 0) {
168       // Put unknown code.
169       const float penalized_score = lattice[i].score + unknown_penalty;
170       const int pos = i + 1;
171       LatticeElement& current_element = lattice[pos];
172       if (current_element.prev_position < 0 ||
173           current_element.score < penalized_score) {
174         current_element = LatticeElement(
175             penalized_score, unknown_code,
176             // If the current state is already reached by unknown code, merge
177             // states.
178             lattice[i].code == unknown_code ? lattice[i].prev_position : i);
179       }
180     }
181     auto lattice_update = [&lattice, i,
182                            piece_scores](const DoubleArrayTrie::Match& m) {
183       LatticeElement& target_element = lattice[i + m.match_length];
184       const float score = lattice[i].score + (*piece_scores)[m.id];
185       if (target_element.prev_position < 0 || target_element.score < score) {
186         target_element = LatticeElement(score, m.id, i);
187       }
188     };
189     piece_matcher.IteratePrefixMatches(
190         utils::string_view(str.data() + i, length - i), lattice_update);
191   }
192 
193   EncoderResult result;
194   if (add_eos) {
195     result.codes.push_back(config.end_code());
196     result.offsets.push_back(length);
197   }
198   if (lattice[length].prev_position >= 0) {
199     for (int pos = length; pos > 0;) {
200       auto code = lattice[pos].code;
201       if (code != config.unknown_code()) {
202         code += config.encoding_offset();
203       }
204       result.codes.push_back(code);
205       pos = lattice[pos].prev_position;
206       result.offsets.push_back(offsets[pos]);
207     }
208   }
209   if (add_bos) {
210     result.codes.push_back(config.start_code());
211     result.offsets.push_back(0);
212   }
213   if (!reverse) {
214     std::reverse(result.codes.begin(), result.codes.end());
215     std::reverse(result.offsets.begin(), result.offsets.end());
216   }
217   return result;
218 }
219 
EncodeString(const std::string & string,const void * config_buffer,bool add_bos,bool add_eos,bool reverse)220 EncoderResult EncodeString(const std::string& string, const void* config_buffer,
221                            bool add_bos, bool add_eos, bool reverse) {
222   // Get the config from the buffer.
223   const EncoderConfig* config = GetEncoderConfig(config_buffer);
224   if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
225     EncoderResult result;
226     result.type = EncoderResultType::WRONG_CONFIG;
227     return result;
228   }
229   std::string normalized_string;
230   std::vector<int> offsets;
231   std::tie(normalized_string, offsets) = NormalizeString(string, *config);
232   return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
233                                 add_eos, reverse);
234 }
235 
236 }  // namespace sentencepiece
237 }  // namespace custom
238 }  // namespace ops
239 }  // namespace tflite
240