1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 namespace sentencepiece {
28 namespace {
29
// UTF-8 byte sequence E2 96 81, i.e. U+2581 LOWER ONE EIGHTH BLOCK. When
// escape_whitespaces is enabled, NormalizeString writes this symbol in place
// of each whitespace byte.
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
31
// Rewrites `input` left to right using the callback `pc`, carrying the
// per-byte offset mapping along with the text.
//
// At each position, `pc(data, len)` is called with the unprocessed tail of
// `input` and returns a tuple (consumed, replacement):
//   * consumed == 0: the current byte is copied through unchanged;
//   * consumed > 0:  `consumed` input bytes are replaced by `replacement`
//                    (an empty replacement deletes them).
// Every output byte produced by a replacement inherits the offset of the
// first input byte it replaced.
//
// Precondition: offsets.size() >= input.size(); offsets[k] is the offset
// associated with input byte k.
//
// Returns the rewritten string and one offset per byte of that string.
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
    const std::string& input, const std::vector<int>& offsets,
    const processing_callback& pc) {
  std::string result_string;
  result_string.reserve(input.size());
  std::vector<int> result_offsets;
  result_offsets.reserve(offsets.size());
  // Input bytes and their offsets advance in lockstep, so a single index
  // addresses both.
  std::size_t pos = 0;
  while (pos < input.size()) {
    auto [consumed, new_string] = pc(input.data() + pos, input.size() - pos);
    if (consumed == 0) {
      // No rule applies here: copy this byte through unchanged.
      result_string.push_back(input[pos]);
      result_offsets.push_back(offsets[pos]);
      ++pos;
      continue;
    }
    result_string.append(new_string.data(), new_string.length());
    // All replacement bytes map back to where the replaced span started.
    result_offsets.insert(result_offsets.end(), new_string.length(),
                          offsets[pos]);
    pos += consumed;
  }
  // Move, rather than copy, the freshly built buffers into the result.
  return std::make_tuple(std::move(result_string), std::move(result_offsets));
}
59
// Returns true iff `c` is one of the ASCII whitespace bytes ' ', '\t', '\r'
// or '\n'. Other whitespace such as '\v' or '\f' is deliberately not matched.
// (Previously declared to return `char`; this is a predicate, so `bool`.)
constexpr bool is_whitespace(char c) {
  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
63
remove_extra_whitespaces(const char * data,int len)64 std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
65 int len) {
66 if (len == 0 || !is_whitespace(*data)) {
67 return std::make_tuple(0, utils::string_view(nullptr, 0));
68 }
69 int num_consumed = 1;
70 for (; num_consumed < len && is_whitespace(data[num_consumed]);
71 ++num_consumed) {
72 }
73 return num_consumed > 1
74 ? std::make_tuple(num_consumed, utils::string_view(" ", 1))
75 : std::make_tuple(0, utils::string_view(nullptr, 0));
76 }
77
find_replacement(const char * data,int len,const DoubleArrayTrie & dat,const flatbuffers::Vector<int8_t> & replacements)78 std::tuple<int, utils::string_view> find_replacement(
79 const char* data, int len, const DoubleArrayTrie& dat,
80 const flatbuffers::Vector<int8_t>& replacements) {
81 const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
82 if (!max_match.empty()) {
83 // Because flatbuffer byte is signed char which is not the same as char,
84 // there is the reinterpret_cast here.
85 const char* replaced_string_ptr =
86 reinterpret_cast<const char*>(replacements.data() + max_match.id);
87 return std::make_tuple(max_match.match_length,
88 utils::string_view(replaced_string_ptr));
89 }
90 return std::make_tuple(0, utils::string_view(nullptr, 0));
91 }
92 } // namespace
93
NormalizeString(const std::string & in_string,const EncoderConfig & config)94 std::tuple<std::string, std::vector<int>> NormalizeString(
95 const std::string& in_string, const EncoderConfig& config) {
96 std::vector<int> output_offsets;
97 std::string result = in_string;
98 output_offsets.reserve(in_string.length());
99 for (int i = 0; i < in_string.length(); ++i) {
100 output_offsets.push_back(i);
101 }
102 if (in_string.empty()) {
103 return std::make_tuple(result, output_offsets);
104 }
105 if (config.add_dummy_prefix()) {
106 result.insert(result.begin(), ' ');
107 output_offsets.insert(output_offsets.begin(), 0);
108 }
109 // Greedely replace normalized_prefixes with normalized_replacements
110 if (config.normalized_prefixes() != nullptr &&
111 config.normalized_replacements() != nullptr) {
112 const DoubleArrayTrie normalized_prefixes_matcher(
113 config.normalized_prefixes()->nodes());
114 const auto norm_replace = [&config, &normalized_prefixes_matcher](
115 const char* data, int len) {
116 return find_replacement(data, len, normalized_prefixes_matcher,
117 *config.normalized_replacements());
118 };
119 std::tie(result, output_offsets) =
120 process_string(result, output_offsets, norm_replace);
121 }
122 if (config.remove_extra_whitespaces()) {
123 std::tie(result, output_offsets) =
124 process_string(result, output_offsets, remove_extra_whitespaces);
125 if (!result.empty() && is_whitespace(result.back())) {
126 result.pop_back();
127 output_offsets.pop_back();
128 }
129 }
130 if (config.escape_whitespaces()) {
131 const auto replace_whitespaces = [](const char* data, int len) {
132 if (len > 0 && is_whitespace(*data)) {
133 return std::make_tuple(1, utils::string_view(kSpaceSymbol));
134 }
135 return std::make_tuple(0, utils::string_view(nullptr, 0));
136 };
137 std::tie(result, output_offsets) =
138 process_string(result, output_offsets, replace_whitespaces);
139 }
140
141 return std::make_tuple(result, output_offsets);
142 }
143
EncodeNormalizedString(const std::string & str,const std::vector<int> & offsets,const EncoderConfig & config,bool add_bos,bool add_eos,bool reverse)144 EncoderResult EncodeNormalizedString(const std::string& str,
145 const std::vector<int>& offsets,
146 const EncoderConfig& config, bool add_bos,
147 bool add_eos, bool reverse) {
148 const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
149 const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
150 const int unknown_code = config.unknown_code();
151 const float unknown_penalty = config.unknown_penalty();
152 struct LatticeElement {
153 float score = 0;
154 int code = -1;
155 int prev_position = -1;
156 LatticeElement(float score_, int code_, int prev_position_)
157 : score(score_), code(code_), prev_position(prev_position_) {}
158 LatticeElement() {}
159 };
160 const int length = str.length();
161 std::vector<LatticeElement> lattice(length + 1);
162 for (int i = 0; i < length; ++i) {
163 if (i > 0 && lattice[i].prev_position < 0) {
164 // This state is unreachable.
165 continue;
166 }
167 if (unknown_code >= 0) {
168 // Put unknown code.
169 const float penalized_score = lattice[i].score + unknown_penalty;
170 const int pos = i + 1;
171 LatticeElement& current_element = lattice[pos];
172 if (current_element.prev_position < 0 ||
173 current_element.score < penalized_score) {
174 current_element = LatticeElement(
175 penalized_score, unknown_code,
176 // If the current state is already reached by unknown code, merge
177 // states.
178 lattice[i].code == unknown_code ? lattice[i].prev_position : i);
179 }
180 }
181 auto lattice_update = [&lattice, i,
182 piece_scores](const DoubleArrayTrie::Match& m) {
183 LatticeElement& target_element = lattice[i + m.match_length];
184 const float score = lattice[i].score + (*piece_scores)[m.id];
185 if (target_element.prev_position < 0 || target_element.score < score) {
186 target_element = LatticeElement(score, m.id, i);
187 }
188 };
189 piece_matcher.IteratePrefixMatches(
190 utils::string_view(str.data() + i, length - i), lattice_update);
191 }
192
193 EncoderResult result;
194 if (add_eos) {
195 result.codes.push_back(config.end_code());
196 result.offsets.push_back(length);
197 }
198 if (lattice[length].prev_position >= 0) {
199 for (int pos = length; pos > 0;) {
200 auto code = lattice[pos].code;
201 if (code != config.unknown_code()) {
202 code += config.encoding_offset();
203 }
204 result.codes.push_back(code);
205 pos = lattice[pos].prev_position;
206 result.offsets.push_back(offsets[pos]);
207 }
208 }
209 if (add_bos) {
210 result.codes.push_back(config.start_code());
211 result.offsets.push_back(0);
212 }
213 if (!reverse) {
214 std::reverse(result.codes.begin(), result.codes.end());
215 std::reverse(result.offsets.begin(), result.offsets.end());
216 }
217 return result;
218 }
219
EncodeString(const std::string & string,const void * config_buffer,bool add_bos,bool add_eos,bool reverse)220 EncoderResult EncodeString(const std::string& string, const void* config_buffer,
221 bool add_bos, bool add_eos, bool reverse) {
222 // Get the config from the buffer.
223 const EncoderConfig* config = GetEncoderConfig(config_buffer);
224 if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
225 EncoderResult result;
226 result.type = EncoderResultType::WRONG_CONFIG;
227 return result;
228 }
229 std::string normalized_string;
230 std::vector<int> offsets;
231 std::tie(normalized_string, offsets) = NormalizeString(string, *config);
232 return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
233 add_eos, reverse);
234 }
235
236 } // namespace sentencepiece
237 } // namespace custom
238 } // namespace ops
239 } // namespace tflite
240