1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/non_max_suppression.h"
16 
17 #include <initializer_list>
18 
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace non_max_suppression {
28 
// ---------------------------------------------------------------------------
// Input tensor positions (shared by NON_MAX_SUPPRESSION_V4 and _V5).
// ---------------------------------------------------------------------------

// Candidate boxes, one [y1, x1, y2, x2] row per box.
// Shape: [num_boxes, 4]. Type: Float.
constexpr int kInputTensorBoxes{0};
// One score per candidate box. Shape: [num_boxes]. Type: Float.
constexpr int kInputTensorScores{1};
// Upper bound on how many boxes are selected; fewer may actually be selected.
// The selected-indices / selected-scores outputs have this length.
// Type: Int32.
constexpr int kInputTensorMaxOutputSize{2};
// Intersection-over-union threshold. Type: Float.
constexpr int kInputTensorIouThreshold{3};
// Score threshold. Type: Float.
constexpr int kInputTensorScoreThreshold{4};
// Soft-NMS sigma; present only for NON_MAX_SUPPRESSION_V5. Type: Float.
constexpr int kInputTensorSigma{5};

// ---------------------------------------------------------------------------
// Output tensor positions for regular NMS (V4).
// ---------------------------------------------------------------------------

// Indices of the selected boxes. Shape: [num_selected_indices]. Type: Int32.
constexpr int kNMSOutputTensorSelectedIndices{0};
// Number of selected boxes. Type: Int32.
constexpr int kNMSOutputTensorNumSelectedIndices{1};

// ---------------------------------------------------------------------------
// Output tensor positions for soft NMS (V5).
// ---------------------------------------------------------------------------

// Indices of the selected boxes. Shape: [num_selected_indices]. Type: Int32.
constexpr int kSoftNMSOutputTensorSelectedIndices{0};
// Scores of the selected boxes. Shape: [num_selected_indices]. Type: Float.
constexpr int kSoftNMSOutputTensorSelectedScores{1};
// Number of selected boxes. Type: Int32.
constexpr int kSoftNMSOutputTensorNumSelectedIndices{2};
61 
SetTensorSizes(TfLiteContext * context,TfLiteTensor * tensor,std::initializer_list<int> values)62 TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
63                             std::initializer_list<int> values) {
64   TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
65   int index = 0;
66   for (const auto& v : values) {
67     size->data[index++] = v;
68   }
69   return context->ResizeTensor(context, tensor, size);
70 }
71 
Prepare(TfLiteContext * context,TfLiteNode * node)72 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
73   const int num_inputs = NumInputs(node);
74   const bool is_soft_nms = num_inputs == 6;
75   if (num_inputs != 5 && num_inputs != 6) {
76     context->ReportError(context, "Found NMS op with invalid num inputs: %d",
77                          NumInputs(node));
78     return kTfLiteError;
79   }
80 
81   // Boxes & Scores.
82   const TfLiteTensor* input_boxes;
83   TF_LITE_ENSURE_OK(
84       context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
85   TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32);
86   TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2);
87   TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4);
88   const int num_boxes = SizeOfDimension(input_boxes, 0);
89   const TfLiteTensor* input_scores;
90   TF_LITE_ENSURE_OK(
91       context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
92   TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32);
93   TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1);
94   TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0));
95 
96   // Max output size.
97   const TfLiteTensor* input_max_output_size;
98   TF_LITE_ENSURE_OK(context,
99                     GetInputSafe(context, node, kInputTensorMaxOutputSize,
100                                  &input_max_output_size));
101   TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
102   TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
103   const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
104   int max_output_size_value = 0;
105   if (is_max_output_size_const) {
106     max_output_size_value = *GetTensorData<int>(input_max_output_size);
107     TF_LITE_ENSURE(context, (max_output_size_value >= 0));
108   }
109 
110   // IoU & Score thresholds.
111   const TfLiteTensor* input_iou_threshold;
112   TF_LITE_ENSURE_OK(context,
113                     GetInputSafe(context, node, kInputTensorIouThreshold,
114                                  &input_iou_threshold));
115   TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
116   TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0);
117   const TfLiteTensor* input_score_threshold;
118   TF_LITE_ENSURE_OK(context,
119                     GetInputSafe(context, node, kInputTensorScoreThreshold,
120                                  &input_score_threshold));
121   TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
122   TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0);
123 
124   if (is_soft_nms) {
125     const TfLiteTensor* input_sigma;
126     TF_LITE_ENSURE_OK(
127         context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
128     TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32);
129     TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0);
130 
131     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3);
132     TfLiteTensor* output_selected_indices;
133     TF_LITE_ENSURE_OK(
134         context,
135         GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
136                       &output_selected_indices));
137     output_selected_indices->type = kTfLiteInt32;
138     TfLiteTensor* output_selected_scores;
139     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
140                                              kSoftNMSOutputTensorSelectedScores,
141                                              &output_selected_scores));
142     output_selected_scores->type = kTfLiteFloat32;
143     TfLiteTensor* output_num_selected_indices;
144     TF_LITE_ENSURE_OK(
145         context,
146         GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
147                       &output_num_selected_indices));
148     output_num_selected_indices->type = kTfLiteInt32;
149     SetTensorSizes(context, output_num_selected_indices, {});
150 
151     if (is_max_output_size_const) {
152       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
153       SetTensorSizes(context, output_selected_scores, {max_output_size_value});
154     } else {
155       SetTensorToDynamic(output_selected_indices);
156       SetTensorToDynamic(output_selected_scores);
157     }
158   } else {
159     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
160     TfLiteTensor* output_selected_indices;
161     TF_LITE_ENSURE_OK(
162         context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
163                                &output_selected_indices));
164     output_selected_indices->type = kTfLiteInt32;
165     TfLiteTensor* output_num_selected_indices;
166     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
167                                              kNMSOutputTensorNumSelectedIndices,
168                                              &output_num_selected_indices));
169     output_num_selected_indices->type = kTfLiteInt32;
170     SetTensorSizes(context, output_num_selected_indices, {});
171 
172     if (is_max_output_size_const) {
173       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
174     } else {
175       SetTensorToDynamic(output_selected_indices);
176     }
177   }
178 
179   return kTfLiteOk;
180 }
181 
// Zeroes every output slot from `num_selected_indices` up to (but excluding)
// `max_output_size` in `selected_indices`. When `selected_scores` is non-null,
// the same slots are zeroed there too.
//
// Why: when fewer than max_output_size boxes are selected, the tail of the
// output tensors would otherwise keep whatever bytes were already in memory.
// Downstream ops such as GATHER treat the indices output as indices, so large
// garbage ints there can cause out-of-bounds accesses / segfaults.
//
// The caller must ensure both buffers hold at least `max_output_size`
// elements. In this kernel the outputs are sized to max_output_size either in
// Prepare (constant size) or in Eval (dynamic size).
void ResetUnusedElementsToZeroes(const int max_output_size,
                                 const int num_selected_indices,
                                 int* selected_indices,
                                 float* selected_scores) {
  const bool has_scores = selected_scores != nullptr;
  int slot = num_selected_indices;
  while (slot < max_output_size) {
    selected_indices[slot] = 0;
    if (has_scores) selected_scores[slot] = 0.0f;
    ++slot;
  }
}
200 
Eval(TfLiteContext * context,TfLiteNode * node)201 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
202   const bool is_soft_nms = NumInputs(node) == 6;
203 
204   const TfLiteTensor* input_boxes;
205   TF_LITE_ENSURE_OK(
206       context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
207   const int num_boxes = SizeOfDimension(input_boxes, 0);
208   const TfLiteTensor* input_scores;
209   TF_LITE_ENSURE_OK(
210       context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
211   const TfLiteTensor* input_max_output_size;
212   TF_LITE_ENSURE_OK(context,
213                     GetInputSafe(context, node, kInputTensorMaxOutputSize,
214                                  &input_max_output_size));
215   const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
216   TF_LITE_ENSURE(context, (max_output_size_value >= 0));
217   const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
218   const TfLiteTensor* input_iou_threshold;
219   TF_LITE_ENSURE_OK(context,
220                     GetInputSafe(context, node, kInputTensorIouThreshold,
221                                  &input_iou_threshold));
222   const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
223   const TfLiteTensor* input_score_threshold;
224   TF_LITE_ENSURE_OK(context,
225                     GetInputSafe(context, node, kInputTensorScoreThreshold,
226                                  &input_score_threshold));
227   const float score_threshold = *GetTensorData<float>(input_score_threshold);
228 
229   TfLiteTensor* output_selected_indices = nullptr;
230   TfLiteTensor* output_selected_scores = nullptr;
231   TfLiteTensor* output_num_selected_indices = nullptr;
232 
233   if (is_soft_nms) {
234     const TfLiteTensor* input_sigma;
235     TF_LITE_ENSURE_OK(
236         context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
237     const float soft_nms_sigma = *GetTensorData<float>(input_sigma);
238     if (soft_nms_sigma < 0) {
239       context->ReportError(context, "Invalid sigma value for soft NMS: %f",
240                            soft_nms_sigma);
241       return kTfLiteError;
242     }
243 
244     TF_LITE_ENSURE_OK(
245         context,
246         GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
247                       &output_selected_indices));
248     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
249                                              kSoftNMSOutputTensorSelectedScores,
250                                              &output_selected_scores));
251     TF_LITE_ENSURE_OK(
252         context,
253         GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
254                       &output_num_selected_indices));
255     if (!is_max_output_size_const) {
256       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
257       SetTensorSizes(context, output_selected_scores, {max_output_size_value});
258     }
259     reference_ops::NonMaxSuppression(
260         input_boxes->data.f, num_boxes, input_scores->data.f,
261         max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
262         output_selected_indices->data.i32, output_selected_scores->data.f,
263         output_num_selected_indices->data.i32);
264     ResetUnusedElementsToZeroes(
265         max_output_size_value, *output_num_selected_indices->data.i32,
266         output_selected_indices->data.i32, output_selected_scores->data.f);
267   } else {
268     TF_LITE_ENSURE_OK(
269         context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
270                                &output_selected_indices));
271     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
272                                              kNMSOutputTensorNumSelectedIndices,
273                                              &output_num_selected_indices));
274     if (!is_max_output_size_const) {
275       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
276     }
277     reference_ops::NonMaxSuppression(
278         input_boxes->data.f, num_boxes, input_scores->data.f,
279         max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0,
280         output_selected_indices->data.i32, /**selected_scores=**/ nullptr,
281         output_num_selected_indices->data.i32);
282     ResetUnusedElementsToZeroes(max_output_size_value,
283                                 *output_num_selected_indices->data.i32,
284                                 output_selected_indices->data.i32, nullptr);
285   }
286 
287   return kTfLiteOk;
288 }
289 }  // namespace non_max_suppression
290 
Register_NON_MAX_SUPPRESSION_V4()291 TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4() {
292   static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare,
293                                  non_max_suppression::Eval};
294   return &r;
295 }
296 
Register_NON_MAX_SUPPRESSION_V5()297 TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5() {
298   static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare,
299                                  non_max_suppression::Eval};
300   return &r;
301 }
302 
303 }  // namespace builtin
304 }  // namespace ops
305 }  // namespace tflite
306