1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/encoder_common.h"
18
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 #include "tensorflow/lite/string_util.h"
21
22 namespace libtextclassifier3 {
23
CreateIntArray(const std::initializer_list<int> & values)24 TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
25 TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
26 int index = 0;
27 for (const int size : values) {
28 array_size->data[index++] = size;
29 }
30 return array_size;
31 }
32
CopyValuesToTensorAndPadOrTruncate(const TfLiteTensor & in,const std::vector<int> & encoding_end_offsets,int start_offset,TfLiteContext * context,TfLiteTensor * out)33 TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
34 const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
35 int start_offset, TfLiteContext* context, TfLiteTensor* out) {
36 TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
37 TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
38 const int output_size = out->dims->data[1];
39 int output_offset = 0;
40 for (int value_index = 0;
41 value_index < encoding_end_offsets.size() && output_offset < output_size;
42 ++value_index) {
43 // Calculate how many elements need to be set with this value.
44 // The low bound depends on the offset from the beginning. If this is 0, it
45 // means that this value it truncated.
46 // The upper bound depends on how many elements are in the output tensor.
47 const int from_this_element =
48 std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
49 output_offset),
50 output_size - output_offset);
51 if (from_this_element == 0) {
52 continue;
53 }
54
55 switch (in.type) {
56 case kTfLiteInt32: {
57 std::fill(out->data.i32 + output_offset,
58 out->data.i32 + output_offset + from_this_element,
59 in.data.i32[value_index]);
60 } break;
61 case kTfLiteFloat32: {
62 std::fill(out->data.f + output_offset,
63 out->data.f + output_offset + from_this_element,
64 in.data.f[value_index]);
65 } break;
66 default:
67 context->ReportError(
68 (context), __FILE__ " Not supported attribute type %d", in.type);
69 return kTfLiteError;
70 }
71 output_offset += from_this_element;
72 }
73 // Do final padding.
74 switch (in.type) {
75 case kTfLiteInt32: {
76 const int32_t value =
77 (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
78 std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
79 value);
80 } break;
81 case kTfLiteFloat32: {
82 const float value =
83 (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
84 std::fill(out->data.f + output_offset, out->data.f + output_size, value);
85 } break;
86 default:
87 break;
88 }
89 return kTfLiteOk;
90 }
91
ResizeOutputTensor(const int max_output_length,TfLiteTensor * tensor,TfLiteContext * context)92 TfLiteStatus ResizeOutputTensor(const int max_output_length,
93 TfLiteTensor* tensor, TfLiteContext* context) {
94 TF_LITE_ENSURE_OK(
95 context, context->ResizeTensor(
96 context, tensor,
97 CreateIntArray({kEncoderBatchSize, max_output_length})));
98 return kTfLiteOk;
99 }
100
CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,const std::vector<int32_t> & data,const int32_t padding_value,TfLiteTensor * output_tensor)101 int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
102 const std::vector<int32_t>& data,
103 const int32_t padding_value,
104 TfLiteTensor* output_tensor) {
105 const int num_skip =
106 std::max(0, static_cast<int>(data.size()) - max_output_length);
107 int output_offset = 0;
108 int32_t* output_buffer = output_tensor->data.i32;
109 for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
110 output_buffer[output_offset] = data[i];
111 }
112
113 // Do padding.
114 for (; output_offset < max_output_length; ++output_offset) {
115 output_buffer[output_offset] = padding_value;
116 }
117
118 // Return number of skipped entries from the beginning.
119 return num_skip;
120 }
121
122 } // namespace libtextclassifier3
123