/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <vector>

#include "utils/base/logging.h"
#include "utils/sentencepiece/double_array_trie.h"
#include "utils/sentencepiece/encoder.h"
#include "utils/sentencepiece/normalizer.h"
#include "utils/sentencepiece/sorted_strings_table.h"
#include "utils/strings/stringpiece.h"
#include "utils/tflite/encoder_common.h"
#include "utils/tflite/text_encoder.h"
#include "utils/tflite/text_encoder_config_generated.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {
namespace {

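// Per-instance op state, created from the serialized config in Initialize and
// deleted in Free.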
struct TextEncoderOp {
  std::unique_ptr<SentencePieceNormalizer> normalizer;
  std::unique_ptr<Encoder> encoder;
  std::unique_ptr<SentencePieceMatcher> matcher;
};

// Input parameters for the op.
// The conversation message as a (1, conversation length) string tensor.
constexpr const int kInputTexts = 0;

// The number of messages, the conversation length, int scalar.
constexpr const int kInputNumInputs = 1;

// Maximum output length of the encoding, int scalar.
constexpr const int kInputMaxLength = 2;

// Additional attributes to align to the sentence pieces, e.g. user ids per
// message.
constexpr const int kInputAttr = 3;

// Output parameters for the op.
// The text sentence piece encodings as ids, (1, max output length) int tensor.
constexpr const int kOutputEncoded = 0;

// Relative position of each sentence piece in the input text,
// (1, max output length) int tensor.
constexpr const int kOutputPosition = 1;

// Output length after trimming to the maximum output length specified.
// int scalar.
constexpr const int kOutputLengths = 2;

// Padded and sentence piece aligned provided attributes, e.g. user id per
// sentence piece.
constexpr const int kOutputAttr = 3;

const char kTextEncoderConfigAttr[] = "text_encoder_config";

// Initializes the text encoder object from serialized options:
//   The options are a flexbuffers attribute map that contains the op config
//   with the key `text_encoder_config` as `TextEncoderConfig`.
void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& attr_map =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();
  const flexbuffers::Blob serialized_config =
      attr_map[kTextEncoderConfigAttr].AsBlob();
  const TextEncoderConfig* config =
      flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());

  std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());

  // Create normalizer from options.
  const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
      config->normalization_charsmap()->Data());
  const int charsmap_trie_nodes_length =
      config->normalization_charsmap()->Length() / sizeof(TrieNode);
  encoder_op->normalizer.reset(new SentencePieceNormalizer(
      DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
      StringPiece(config->normalization_charsmap_values()->data(),
                  config->normalization_charsmap_values()->size()),
      config->add_dummy_prefix(), config->remove_extra_whitespaces(),
      config->escape_whitespaces()));

  const int num_pieces = config->pieces_scores()->Length();

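  // Create the sentence piece matcher, backed either by a mapped double array
  // trie or by a sorted strings table, depending on the config.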
  switch (config->matcher_type()) {
    case SentencePieceMatcherType_MAPPED_TRIE: {
      const TrieNode* pieces_trie_nodes =
          reinterpret_cast<const TrieNode*>(config->pieces()->Data());
      const int pieces_trie_nodes_length =
          config->pieces()->Length() / sizeof(TrieNode);
      encoder_op->matcher.reset(
          new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
      break;
    }
    case SentencePieceMatcherType_SORTED_STRING_TABLE: {
      encoder_op->matcher.reset(new SortedStringsTable(
          num_pieces, config->pieces_offsets()->data(),
          StringPiece(config->pieces()->data(), config->pieces()->Length())));
      break;
    }
    default: {
      TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
      return nullptr;
    }
  }
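  // Create the encoder over the matcher, the piece scores and the special
  // codes from the config.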
  encoder_op->encoder.reset(new Encoder(
      encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
      config->start_code(), config->end_code(), config->encoding_offset(),
      config->unknown_code(), config->unknown_score()));
  return encoder_op.release();
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<TextEncoderOp*>(buffer);
}

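// Resizes the encoding, position and attribute output tensors to the given
// maximum output length.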
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
                                 int max_output_length) {
  TF_LITE_ENSURE_OK(
      context,
      ResizeOutputTensor(max_output_length,
                         &context->tensors[node->outputs->data[kOutputEncoded]],
                         context));

  TF_LITE_ENSURE_OK(
      context,
      ResizeOutputTensor(
          max_output_length,
          &context->tensors[node->outputs->data[kOutputPosition]], context));

  const int num_output_attrs = node->outputs->size - kOutputAttr;
  for (int i = 0; i < num_output_attrs; ++i) {
    TF_LITE_ENSURE_OK(
        context,
        ResizeOutputTensor(
            max_output_length,
            &context->tensors[node->outputs->data[kOutputAttr + i]], context));
  }
  return kTfLiteOk;
}

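// Checks input and output shapes and types, and resizes the outputs if the
// maximum output length is already known; otherwise marks them dynamic so
// they can be resized in Eval.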
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that the batch dimension is kEncoderBatchSize.
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputTexts]];
  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);

  TfLiteTensor& output_lengths =
      context->tensors[node->outputs->data[kOutputLengths]];
  TfLiteTensor& output_encoded =
      context->tensors[node->outputs->data[kOutputEncoded]];
  TfLiteTensor& output_positions =
      context->tensors[node->outputs->data[kOutputPosition]];

  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &output_lengths,
                                          CreateIntArray({kEncoderBatchSize})));

  // Check that there are enough outputs for attributes.
  const int num_output_attrs = node->outputs->size - kOutputAttr;
  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);

  // Copy attribute types from input to output tensors.
  for (int i = 0; i < num_output_attrs; ++i) {
    TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
    TfLiteTensor& output =
        context->tensors[node->outputs->data[kOutputAttr + i]];
    output.type = input.type;
  }

  const TfLiteTensor& output_length =
      context->tensors[node->inputs->data[kInputMaxLength]];

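  // If the maximum output length is a constant, resize the outputs now;
  // otherwise defer resizing to Eval by marking them dynamic.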
  if (tflite::IsConstantTensor(&output_length)) {
    return ResizeOutputTensors(context, node, output_length.data.i64[0]);
  } else {
    tflite::SetTensorToDynamic(&output_encoded);
    tflite::SetTensorToDynamic(&output_positions);
    for (int i = 0; i < num_output_attrs; ++i) {
      TfLiteTensor& output_attr =
          context->tensors[node->outputs->data[kOutputAttr + i]];
      tflite::SetTensorToDynamic(&output_attr);
    }
  }

  return kTfLiteOk;
}

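// Normalizes and encodes the input messages and writes the encoded ids, the
// piece positions, the resulting length and the aligned attributes to the
// output tensors.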
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  if (node->user_data == nullptr) {
    return kTfLiteError;
  }
  const TextEncoderOp* encoder_op =
      reinterpret_cast<TextEncoderOp*>(node->user_data);
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputTexts]];
  const int num_strings = tflite::GetStringCount(&input_text);
  // Check that the number of strings matches the length parameter.
  const int num_strings_param =
      context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
  TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);

  TfLiteTensor& output_encoded =
      context->tensors[node->outputs->data[kOutputEncoded]];
  if (tflite::IsDynamicTensor(&output_encoded)) {
    const TfLiteTensor& output_length =
        context->tensors[node->inputs->data[kInputMaxLength]];
    TF_LITE_ENSURE_OK(
        context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
  }
  TfLiteTensor& output_positions =
      context->tensors[node->outputs->data[kOutputPosition]];

  std::vector<int> encoded_total;
  std::vector<int> encoded_offsets;
  std::vector<int> encoded_positions;
  encoded_offsets.reserve(num_strings);
  const int max_output_length = output_encoded.dims->data[1];
  const int max_encoded_position = max_output_length;

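  // Normalize and encode each message, recording the end offset of each
  // message's encoding and the relative position of each piece within its
  // message (capped at the maximum position).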
  for (int i = 0; i < num_strings; ++i) {
    const auto& strref = tflite::GetString(&input_text, i);
    std::string normalized;
    TF_LITE_ENSURE(context,
                   encoder_op->normalizer->Normalize(
                       StringPiece(strref.str, strref.len), &normalized));
    std::vector<int> encoded;
    TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
    encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
    encoded_offsets.push_back(encoded_total.size());
    for (int j = 0; j < encoded.size(); j++) {
      encoded_positions.push_back(std::min(j, max_encoded_position - 1));
    }
  }

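  // Copy the encoded ids and positions to the outputs, padding or truncating
  // them to the maximum output length, and store the resulting length.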
  const int num_skip = CopyDataToTensorAndPadOrTruncate(
      max_output_length, encoded_total,
      /*padding_value=*/encoded_total.back(), &output_encoded);
  TfLiteTensor& output_lengths =
      context->tensors[node->outputs->data[kOutputLengths]];
  output_lengths.data.i32[0] = encoded_total.size() - num_skip;
  CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
                                   /*padding_value=*/max_encoded_position,
                                   &output_positions);

  // Process attributes; all checks of sizes and types are done in Prepare.
  const int num_output_attrs = node->outputs->size - kOutputAttr;
  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
  for (int i = 0; i < num_output_attrs; ++i) {
    TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
        context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
        num_skip, context,
        &context->tensors[node->outputs->data[kOutputAttr + i]]);
    if (attr_status != kTfLiteOk) {
      return attr_status;
    }
  }

  return kTfLiteOk;
}

}  // namespace
}  // namespace libtextclassifier3

namespace tflite {
namespace ops {
namespace custom {

TfLiteRegistration* Register_TEXT_ENCODER() {
  static TfLiteRegistration registration = {
      libtextclassifier3::Initialize, libtextclassifier3::Free,
      libtextclassifier3::Prepare, libtextclassifier3::Eval};
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite