1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/dist_diversification.h"
18 
19 #include <algorithm>
20 #include "tensorflow/lite/context.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/model.h"
23 
24 namespace libtextclassifier3 {
25 namespace {
26 
// Greedily selects a diverse subset of row indices from a distance matrix.
//
// Rows are visited in increasing order. Row 0 is always taken first (when
// there is at least one row and at least one result is requested); every
// later row is taken only if its distance to each already selected row is
// at least `min_distance`.
//
// Args:
//   distance_matrix: callable invoked as `distance_matrix(row, col)`; its
//     result is compared against `min_distance`.
//   matrix_size: number of rows that may be visited.
//   min_distance: a candidate is rejected if it is closer than this to any
//     selected row.
//   max_num_results: upper bound on the number of returned indices.
//
// Returns the selected indices in increasing order. The result is empty if
// `matrix_size` or `max_num_results` is not positive.
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  std::vector<int> result;
  // Nothing can be selected from an empty matrix or with a non-positive cap.
  // Returning early also keeps a negative cap from reaching reserve(), where
  // it would be converted to a huge unsigned size.
  if (matrix_size <= 0 || max_num_results <= 0) {
    return result;
  }

  // At most one index per row is ever selected, so never reserve more than
  // the number of rows even if the caller asks for a very large cap.
  result.reserve(std::min(matrix_size, max_num_results));
  result.push_back(0);

  const std::vector<int>::size_type limit = max_num_results;
  for (int index = 1; index < matrix_size && result.size() < limit; ++index) {
    const bool too_close = std::any_of(
        result.begin(), result.end(), [&](const int selected_index) {
          return distance_matrix(index, selected_index) < min_distance;
        });
    if (!too_close) {
      result.push_back(index);
    }
  }
  return result;
}
56 
// Positions of the op's inputs within `node->inputs`.
enum DistDiversificationInputs {
  // Distance matrix; read as float values addressed row * dim + col.
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,
  // Minimum allowed distance between selected rows; first float is read.
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE = 1,
  // Requested number of results; first int32 is read.
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS = 2,
};
63 
// Positions of the op's outputs within `node->outputs`.
enum DistDiversificationOutputs {
  // int32 tensor receiving the selected row indices, padded with -1.
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
  // Single-element int32 tensor receiving the number of selected indices.
  DIST_DIVERSIFICATION_OUTPUT_LENGTH = 1
};
69 
CreateSizeArray(const std::initializer_list<int> & sizes)70 TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
71   TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
72   int index = 0;
73   for (const int size : sizes) {
74     array_size->data[index++] = size;
75   }
76   return array_size;
77 }
78 
AllocateOutputIndexes(TfLiteContext * context,TfLiteNode * node)79 TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
80   const TfLiteTensor& num_results =
81       context
82           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
83   TfLiteTensor& output_indices =
84       context
85           ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
86   return context->ResizeTensor(context, &output_indices,
87                                CreateSizeArray({num_results.data.i32[0]}));
88 }
89 
Prepare(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91   const TfLiteTensor& num_results =
92       context
93           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
94   if (tflite::IsConstantTensor(&num_results)) {
95     TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
96   } else {
97     TfLiteTensor& output_indices =
98         context
99             ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
100     tflite::SetTensorToDynamic(&output_indices);
101   }
102   TfLiteTensor& output_length =
103       context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
104   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
105                                                    CreateSizeArray({1})));
106   return kTfLiteOk;
107 }
108 
Eval(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
110   TfLiteTensor& output_indices =
111       context
112           ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
113   if (tflite::IsDynamicTensor(&output_indices)) {
114     TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
115   }
116   const TfLiteTensor& distance_matrix =
117       context->tensors[node->inputs
118                            ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
119   const int distance_matrix_dim = distance_matrix.dims->data[0];
120   const float min_distance =
121       context
122           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
123           .data.f[0];
124   const int num_results =
125       context
126           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
127           .data.i32[0];
128   const auto indices = DiversifyByDistance(
129       [&](int row, int col) {
130         return distance_matrix.data.f[row * distance_matrix_dim + col];
131       },
132       distance_matrix_dim, min_distance, num_results);
133   std::copy(indices.begin(), indices.end(), output_indices.data.i32);
134   std::fill_n(output_indices.data.i32 + indices.size(),
135               num_results - indices.size(), -1);
136   TfLiteTensor& output_length =
137       context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
138   *output_length.data.i32 = indices.size();
139   return kTfLiteOk;
140 }
141 
142 }  // namespace
143 }  // namespace libtextclassifier3
144 
145 namespace tflite {
146 namespace ops {
147 namespace custom {
Register_DISTANCE_DIVERSIFICATION()148 TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
149   static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
150                                  libtextclassifier3::Eval};
151   return &r;
152 }
153 }  // namespace custom
154 }  // namespace ops
155 }  // namespace tflite
156