1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/dist_diversification.h"
18
19 #include <algorithm>
20 #include "tensorflow/lite/context.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/model.h"
23
24 namespace libtextclassifier3 {
25 namespace {
26
// Greedily selects well-separated rows of a square distance matrix.
//
// Row 0 is always taken first (when there is anything to take). Every later
// row, visited in increasing order, is taken if its distance to each
// already-taken row is at least `min_distance`. Selection stops once
// `max_num_results` rows have been taken or every row has been visited.
//
// `distance_matrix` is any callable invoked as `distance_matrix(row, col)`
// returning a value comparable with `min_distance`. Only entries with
// row = candidate and col = an already-selected (smaller) index are read.
//
// Returns the selected row indices in increasing order. The result is empty
// when `matrix_size` or `max_num_results` is not positive, so it never holds
// more than `max_num_results` entries nor an index >= `matrix_size`.
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  std::vector<int> result;
  if (matrix_size <= 0 || max_num_results <= 0) {
    // Nothing can be selected; in particular index 0 does not exist (empty
    // matrix) or would exceed the requested number of results.
    return result;
  }
  result.reserve(static_cast<size_t>(std::min(matrix_size, max_num_results)));
  result.push_back(0);
  for (int index = 1; index < matrix_size &&
                      static_cast<int>(result.size()) < max_num_results;
       ++index) {
    // Written as !(d < min_distance) rather than d >= min_distance so that a
    // NaN distance is treated as "not too close", exactly like the original
    // too-close test.
    const bool far_from_all_selected =
        std::all_of(result.begin(), result.end(), [&](const int selected) {
          return !(distance_matrix(index, selected) < min_distance);
        });
    if (far_from_all_selected) {
      result.push_back(index);
    }
  }
  return result;
}
56
// Positions of the op's input tensors within `node->inputs`.
enum DistDiversificationInputs {
  // Float data read row-major as a dims[0] x dims[0] matrix; entry
  // (row, col) is the distance between rows `row` and `col`.
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,
  // Float tensor; only element 0 is read (the minimum allowed distance).
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE = 1,
  // Int32 tensor; only element 0 is read. It is both the maximum number of
  // selected rows and the length of the indices output.
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS = 2
};

// Positions of the op's output tensors within `node->outputs`.
enum DistDiversificationOutputs {
  // Int32 vector of selected row indices, padded with -1 at the end.
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
  // Int32 tensor of shape [1]: how many entries of the indices output are
  // valid.
  DIST_DIVERSIFICATION_OUTPUT_LENGTH = 1,
};
69
CreateSizeArray(const std::initializer_list<int> & sizes)70 TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
71 TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
72 int index = 0;
73 for (const int size : sizes) {
74 array_size->data[index++] = size;
75 }
76 return array_size;
77 }
78
AllocateOutputIndexes(TfLiteContext * context,TfLiteNode * node)79 TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
80 const TfLiteTensor& num_results =
81 context
82 ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
83 TfLiteTensor& output_indices =
84 context
85 ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
86 return context->ResizeTensor(context, &output_indices,
87 CreateSizeArray({num_results.data.i32[0]}));
88 }
89
Prepare(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91 const TfLiteTensor& num_results =
92 context
93 ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
94 if (tflite::IsConstantTensor(&num_results)) {
95 TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
96 } else {
97 TfLiteTensor& output_indices =
98 context
99 ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
100 tflite::SetTensorToDynamic(&output_indices);
101 }
102 TfLiteTensor& output_length =
103 context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
104 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
105 CreateSizeArray({1})));
106 return kTfLiteOk;
107 }
108
Eval(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
110 TfLiteTensor& output_indices =
111 context
112 ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
113 if (tflite::IsDynamicTensor(&output_indices)) {
114 TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
115 }
116 const TfLiteTensor& distance_matrix =
117 context->tensors[node->inputs
118 ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
119 const int distance_matrix_dim = distance_matrix.dims->data[0];
120 const float min_distance =
121 context
122 ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
123 .data.f[0];
124 const int num_results =
125 context
126 ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
127 .data.i32[0];
128 const auto indices = DiversifyByDistance(
129 [&](int row, int col) {
130 return distance_matrix.data.f[row * distance_matrix_dim + col];
131 },
132 distance_matrix_dim, min_distance, num_results);
133 std::copy(indices.begin(), indices.end(), output_indices.data.i32);
134 std::fill_n(output_indices.data.i32 + indices.size(),
135 num_results - indices.size(), -1);
136 TfLiteTensor& output_length =
137 context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
138 *output_length.data.i32 = indices.size();
139 return kTfLiteOk;
140 }
141
142 } // namespace
143 } // namespace libtextclassifier3
144
145 namespace tflite {
146 namespace ops {
147 namespace custom {
Register_DISTANCE_DIVERSIFICATION()148 TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
149 static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
150 libtextclassifier3::Eval};
151 return &r;
152 }
153 } // namespace custom
154 } // namespace ops
155 } // namespace tflite
156