/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/non_max_suppression_op.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <queue>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;

static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
                                   const Tensor& scores) {
  // The shape of 'scores' is [num_boxes].
  OP_REQUIRES(context, scores.dims() == 1,
              errors::InvalidArgument("scores must be 1-D",
                                      scores.shape().DebugString()));
  OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
              errors::InvalidArgument("scores has incompatible shape"));
}

static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
                                             const Tensor& overlaps,
                                             int* num_boxes) {
  // The shape of 'overlaps' is [num_boxes, num_boxes].
  OP_REQUIRES(context, overlaps.dims() == 2,
              errors::InvalidArgument("overlaps must be 2-D",
                                      overlaps.shape().DebugString()));

  *num_boxes = overlaps.dim_size(0);
  OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
              errors::InvalidArgument("overlaps must be square",
                                      overlaps.shape().DebugString()));
}

static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
                                         const Tensor& boxes, int* num_boxes) {
  // The shape of 'boxes' is [num_boxes, 4].
  OP_REQUIRES(context, boxes.dims() == 2,
              errors::InvalidArgument("boxes must be 2-D",
                                      boxes.shape().DebugString()));
  *num_boxes = boxes.dim_size(0);
  OP_REQUIRES(context, boxes.dim_size(1) == 4,
              errors::InvalidArgument("boxes must have 4 columns"));
}

static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
                                              int num_boxes,
                                              const Tensor& scores) {
  // The shape of 'scores' is [batch_size, num_boxes, num_classes].
  OP_REQUIRES(context, scores.dims() == 3,
              errors::InvalidArgument("scores must be 3-D",
                                      scores.shape().DebugString()));
  OP_REQUIRES(context, scores.dim_size(1) == num_boxes,
              errors::InvalidArgument("scores has incompatible shape"));
}

static inline void ParseAndCheckCombinedNMSBoxSizes(OpKernelContext* context,
                                                    const Tensor& boxes,
                                                    int* num_boxes,
                                                    const int num_classes) {
  // The shape of 'boxes' is [batch_size, num_boxes, q, 4], where q is either
  // 1 (boxes shared across classes) or num_classes (per-class boxes).
  OP_REQUIRES(context, boxes.dims() == 4,
              errors::InvalidArgument("boxes must be 4-D",
                                      boxes.shape().DebugString()));

  bool box_check = boxes.dim_size(2) == 1 || boxes.dim_size(2) == num_classes;
  OP_REQUIRES(context, box_check,
              errors::InvalidArgument(
                  "third dimension of boxes must be either 1 or num classes"));
  *num_boxes = boxes.dim_size(1);
  OP_REQUIRES(context, boxes.dim_size(3) == 4,
              errors::InvalidArgument("boxes must have 4 columns"));
}
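
// IoU is the ratio of intersection area to union area:
//   iou = intersection_area / (area_i + area_j - intersection_area)
// Worked example: boxes (0, 0, 2, 2) and (1, 1, 3, 3) each have area 4 and
// intersect in a unit square of area 1, so iou = 1 / (4 + 4 - 1) = 1/7.
// Coordinates are min/max-normalized first, so a box may list its corners in
// either order. Boxes with non-positive area never trigger suppression, since
// the check below returns false for them.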
// Returns true when the intersection-over-union (IoU) of boxes i and j is
// strictly greater than iou_threshold.
template <typename T>
static inline bool IOUGreaterThanThreshold(
    typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
  const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
  const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
  const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
  const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
  const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
  const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
  const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
  const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
  const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
  const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
  if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) {
    return false;
  }
  const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
  const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
  const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
  const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
  const T intersection_area =
      std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
      std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
  const T iou = intersection_area / (area_i + area_j - intersection_area);
  return iou > iou_threshold;
}

static inline bool OverlapsGreaterThanThreshold(
    typename TTypes<float, 2>::ConstTensor overlaps, int i, int j,
    float overlap_threshold) {
  return overlaps(i, j) > overlap_threshold;
}

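// The factories below wrap the pairwise checks above into a
// std::function<bool(int, int)> suppression predicate, so that
// DoNonMaxSuppressionOp can stay agnostic about whether suppression is driven
// by box IoU or by a precomputed overlap matrix.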
template <typename T>
static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
    const Tensor& boxes, float threshold) {
  typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
  return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
                   std::placeholders::_1, std::placeholders::_2,
                   static_cast<T>(threshold));
}

static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
    const Tensor& overlaps, float threshold) {
  typename TTypes<float, 2>::ConstTensor overlaps_data =
      overlaps.tensor<float, 2>();
  return std::bind(&OverlapsGreaterThanThreshold, overlaps_data,
                   std::placeholders::_1, std::placeholders::_2, threshold);
}

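// Greedy non-max suppression over a single set of boxes:
//   1. Push every candidate whose score exceeds score_threshold into a
//      max-heap keyed on score.
//   2. Repeatedly pop the best-scoring candidate and keep it only if
//      suppress_check_fn says it does not overlap any already-selected box.
//   3. Stop once output_size boxes are selected or the heap runs dry.
// When pad_to_max_output_size is true, the selection is zero-padded so the
// output always holds exactly output_size indices; the true count is reported
// through ptr_num_valid_outputs.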
template <typename T>
void DoNonMaxSuppressionOp(
    OpKernelContext* context, const Tensor& scores, int num_boxes,
    const Tensor& max_output_size, const float score_threshold,
    const std::function<bool(int, int)>& suppress_check_fn,
    bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
  const int output_size = max_output_size.scalar<int>()();

  std::vector<T> scores_data(num_boxes);
  std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());

  // Data structure for a selection candidate in NMS.
  struct Candidate {
    int box_index;
    T score;
  };

  auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
    return bs_i.score < bs_j.score;
  };
  std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
      candidate_priority_queue(cmp);
  for (int i = 0; i < scores_data.size(); ++i) {
    if (static_cast<float>(scores_data[i]) > score_threshold) {
      candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
    }
  }

  std::vector<int> selected;
  std::vector<T> selected_scores;
  Candidate next_candidate;

  while (selected.size() < output_size && !candidate_priority_queue.empty()) {
    next_candidate = candidate_priority_queue.top();
    candidate_priority_queue.pop();

    // Overlapping boxes are likely to have similar scores, therefore we
    // iterate through the previously selected boxes backwards in order to
    // see if `next_candidate` should be suppressed.
    bool should_select = true;

    for (int j = static_cast<int>(selected.size()) - 1; j >= 0; --j) {
      if (suppress_check_fn(next_candidate.box_index, selected[j])) {
        should_select = false;
        break;
      }
    }

    if (should_select) {
      selected.push_back(next_candidate.box_index);
      selected_scores.push_back(next_candidate.score);
    }
  }

  int num_valid_outputs = selected.size();
  if (pad_to_max_output_size) {
    selected.resize(output_size, 0);
    selected_scores.resize(output_size, static_cast<T>(0));
  }
  if (ptr_num_valid_outputs) {
    *ptr_num_valid_outputs = num_valid_outputs;
  }

  // Allocate output tensors.
  Tensor* output_indices = nullptr;
  TensorShape output_shape({static_cast<int>(selected.size())});
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, output_shape, &output_indices));
  TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
  std::copy_n(selected.begin(), selected.size(), output_indices_data.data());
}

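// CombinedNonMaxSuppression helper: runs class-wise NMS independently for
// every batch element, then merges the per-class survivors, sorts them by
// score, clips the top max_detections boxes to the [0, 0, 1, 1] window, and
// zero-pads each batch to a fixed per_batch_size so the outputs are dense.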
void BatchedNonMaxSuppressionOp(
    OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
    int num_boxes, const int max_size_per_class, const int total_size_per_batch,
    const float score_threshold, const float iou_threshold,
    bool pad_per_class = false) {
  int q = inp_boxes.dim_size(2);
  int num_classes = inp_scores.dim_size(2);
  const int num_batches = inp_boxes.dim_size(0);

  // Default clip window of [0, 0, 1, 1] if none specified.
  std::vector<float> clip_window{0, 0, 1, 1};

  // [num_batches, per_batch_size * 4]
  std::vector<std::vector<float>> nmsed_boxes(num_batches);
  // [num_batches, per_batch_size]
  std::vector<std::vector<float>> nmsed_scores(num_batches);
  // [num_batches, per_batch_size]
  std::vector<std::vector<float>> nmsed_classes(num_batches);
  // [num_batches]
  std::vector<int> final_valid_detections;

  int per_batch_size = total_size_per_batch;

  // Perform the non_max_suppression operation for each batch independently.
  for (int batch = 0; batch < num_batches; ++batch) {
    // Dims of per_batch_boxes: [num_boxes, q, 4].
    Tensor per_batch_boxes = inp_boxes.Slice(batch, batch + 1);
    // Dims of per_batch_scores: [num_boxes, num_classes].
    Tensor per_batch_scores = inp_scores.Slice(batch, batch + 1);

    struct ResultCandidate {
      int box_index;
      float score;
      int class_idx;
      float box_coord[4];
    };

    std::vector<ResultCandidate> result_candidate_vec;

    float* scores_data = per_batch_scores.unaligned_flat<float>().data();
    float* boxes_data = per_batch_boxes.unaligned_flat<float>().data();

    // Iterate through all classes.
    for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
      std::vector<float> class_scores_data;
      class_scores_data.reserve(num_boxes);
      std::vector<float> class_boxes_data;
      class_boxes_data.reserve(num_boxes * 4);

      for (int box = 0; box < num_boxes; ++box) {
        // Get the scores per class; class_scores_data dims are [num_boxes].
        class_scores_data.push_back(scores_data[box * num_classes + class_idx]);
        for (int cid = 0; cid < 4; ++cid) {
          if (q > 1) {
            // Get the boxes per class; class_boxes_data dims are
            // [num_boxes, 4].
            class_boxes_data.push_back(
                boxes_data[(box * q + class_idx) * 4 + cid]);
          } else {
            class_boxes_data.push_back(boxes_data[box * 4 + cid]);
          }
        }
      }

      // Copy class_boxes_data into a tensor.
      TensorShape boxesShape({num_boxes, 4});
      Tensor boxes(per_batch_boxes.dtype(), boxesShape);
      std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
                  boxes.unaligned_flat<float>().data());

      const int size_per_class = std::min(max_size_per_class, num_boxes);
      // Do NMS and collect the selected candidate indices for this class.
      // Data structure for a selection candidate in NMS.
      struct Candidate {
        int box_index;
        float score;
      };
      auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
        return bs_i.score > bs_j.score;
      };
      std::vector<Candidate> candidate_vector;
      for (int i = 0; i < class_scores_data.size(); ++i) {
        if (class_scores_data[i] > score_threshold) {
          candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
        }
      }

      std::vector<int> selected;
      std::vector<float> selected_boxes;
      Candidate next_candidate;

      std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
      const Tensor const_boxes = boxes;
      typename TTypes<float, 2>::ConstTensor boxes_data =
          const_boxes.tensor<float, 2>();
      int candidate_idx = 0;
      while (selected.size() < size_per_class &&
             candidate_idx < candidate_vector.size()) {
        next_candidate = candidate_vector[candidate_idx++];

        // Overlapping boxes are likely to have similar scores, therefore we
        // iterate through the previously selected boxes backwards in order
        // to see if `next_candidate` should be suppressed.
        bool should_select = true;
        for (int j = selected.size() - 1; j >= 0; --j) {
          if (IOUGreaterThanThreshold(boxes_data, next_candidate.box_index,
                                      selected[j], iou_threshold)) {
            should_select = false;
            break;
          }
        }

        if (should_select) {
          selected.push_back(next_candidate.box_index);
          // Add the selected box to the result candidates, sorted by score.
          int id = next_candidate.box_index;
          ResultCandidate rc = {next_candidate.box_index,
                                next_candidate.score,
                                class_idx,
                                {boxes_data(id, 0), boxes_data(id, 1),
                                 boxes_data(id, 2), boxes_data(id, 3)}};
          result_candidate_vec.push_back(rc);
        }
      }
    }

    auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
      return rc_i.score > rc_j.score;
    };
    std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);

    int max_detections = 0;
    // If pad_per_class is false, we always pad to max_total_size.
    if (!pad_per_class) {
      max_detections = std::min(static_cast<int>(result_candidate_vec.size()),
                                total_size_per_batch);
      per_batch_size = total_size_per_batch;
    } else {
      per_batch_size =
          std::min(total_size_per_batch, max_size_per_class * num_classes);
      max_detections = std::min(per_batch_size,
                                static_cast<int>(result_candidate_vec.size()));
    }

    final_valid_detections.push_back(max_detections);

    int curr_total_size = max_detections;
    int result_idx = 0;
    // Pick the top max_detections values and clip them to the window.
    while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
      ResultCandidate next_candidate = result_candidate_vec[result_idx++];
      // Add to the final output vectors.
      nmsed_boxes[batch].push_back(
          std::max(std::min(next_candidate.box_coord[0], clip_window[2]),
                   clip_window[0]));
      nmsed_boxes[batch].push_back(
          std::max(std::min(next_candidate.box_coord[1], clip_window[3]),
                   clip_window[1]));
      nmsed_boxes[batch].push_back(
          std::max(std::min(next_candidate.box_coord[2], clip_window[2]),
                   clip_window[0]));
      nmsed_boxes[batch].push_back(
          std::max(std::min(next_candidate.box_coord[3], clip_window[3]),
                   clip_window[1]));
      nmsed_scores[batch].push_back(next_candidate.score);
      nmsed_classes[batch].push_back(next_candidate.class_idx);
      curr_total_size--;
    }

    nmsed_boxes[batch].resize(per_batch_size * 4, 0);
    nmsed_scores[batch].resize(per_batch_size, 0);
    nmsed_classes[batch].resize(per_batch_size, 0);
  }

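  // Flatten the per-batch vectors into the four dense outputs. The flat index
  // math mirrors the output shapes: nmsed_boxes is
  // [num_batches, per_batch_size, 4], so element (i, j, k) lives at flat
  // offset i * per_batch_size * 4 + j * 4 + k.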
  Tensor* nmsed_boxes_t = nullptr;
  TensorShape boxes_shape({num_batches, per_batch_size, 4});
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, boxes_shape, &nmsed_boxes_t));
  auto nmsed_boxes_flat = nmsed_boxes_t->template flat<float>();

  Tensor* nmsed_scores_t = nullptr;
  TensorShape scores_shape({num_batches, per_batch_size});
  OP_REQUIRES_OK(context,
                 context->allocate_output(1, scores_shape, &nmsed_scores_t));
  auto nmsed_scores_flat = nmsed_scores_t->template flat<float>();

  Tensor* nmsed_classes_t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(2, scores_shape, &nmsed_classes_t));
  auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();

  Tensor* valid_detections_t = nullptr;
  TensorShape valid_detections_shape({num_batches});
  OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
                                                   &valid_detections_t));
  auto valid_detections_flat = valid_detections_t->template flat<int>();

  for (int i = 0; i < num_batches; ++i) {
    valid_detections_flat(i) = final_valid_detections[i];
    for (int j = 0; j < per_batch_size; ++j) {
      nmsed_scores_flat(i * per_batch_size + j) = nmsed_scores[i][j];
      nmsed_classes_flat(i * per_batch_size + j) = nmsed_classes[i][j];
      for (int k = 0; k < 4; ++k) {
        nmsed_boxes_flat(i * per_batch_size * 4 + j * 4 + k) =
            nmsed_boxes[i][j * 4 + k];
      }
    }
  }
}

}  // namespace

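// NonMaxSuppression (v1): iou_threshold is a kernel attribute rather than an
// input tensor, and there is no score threshold input, so effectively every
// box is a candidate (score_threshold is set to the lowest representable
// float below).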
template <typename Device>
class NonMaxSuppressionOp : public OpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));

    OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto suppress_check_fn =
        CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);

    const float score_threshold_val = std::numeric_limits<float>::lowest();
    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
                                 score_threshold_val, suppress_check_fn);
  }

 private:
  float iou_threshold_;
};

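// NonMaxSuppressionV2: identical selection logic, but iou_threshold arrives
// as a runtime input tensor instead of an attribute, and the kernel is
// templated on the box/score element type T (float or Eigen::half).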
template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const float iou_threshold_val = iou_threshold.scalar<float>()();

    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto suppress_check_fn =
        CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);

    const float score_threshold_val = std::numeric_limits<float>::lowest();
    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
                             score_threshold_val, suppress_check_fn);
  }
};

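// NonMaxSuppressionV3: extends V2 with a score_threshold input, so boxes
// scoring at or below the threshold are dropped before suppression begins.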
template <typename Device, typename T>
class NonMaxSuppressionV3Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const float iou_threshold_val = iou_threshold.scalar<float>()();
    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }

    auto suppress_check_fn =
        CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);

    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
                             score_threshold_val, suppress_check_fn);
  }
};

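// NonMaxSuppressionV4: extends V3 with a pad_to_max_output_size attribute and
// a second scalar output holding the number of valid (non-padding) indices.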
template <typename Device, typename T>
class NonMaxSuppressionV4Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const float iou_threshold_val = iou_threshold.scalar<float>()();
    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }

    auto suppress_check_fn =
        CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
    int num_valid_outputs;

    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
                             score_threshold_val, suppress_check_fn,
                             pad_to_max_output_size_, &num_valid_outputs);

    // Allocate a scalar output tensor for the number of indices computed.
    Tensor* num_outputs_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, tensorflow::TensorShape{}, &num_outputs_t));
    num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
  }

 private:
  bool pad_to_max_output_size_;
};

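// NonMaxSuppressionWithOverlaps: instead of computing IoU from box
// coordinates, suppression is driven by a caller-supplied
// [num_boxes, num_boxes] overlaps matrix compared against overlap_threshold.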
template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
 public:
  explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // overlaps: [num_boxes, num_boxes]
    const Tensor& overlaps = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // overlap_threshold: scalar
    const Tensor& overlap_threshold = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
        errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
                                overlap_threshold.shape().DebugString()));
    const float overlap_threshold_val = overlap_threshold.scalar<float>()();

    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    int num_boxes = 0;
    ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto suppress_check_fn =
        CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);

    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
                                 score_threshold_val, suppress_check_fn);
  }
};

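// CombinedNonMaxSuppression: batched, multi-class NMS. See
// BatchedNonMaxSuppressionOp above for the per-batch algorithm; outputs are
// dense, score-sorted, and padded per batch (or per class when the
// pad_per_class attribute is set).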
template <typename Device>
class CombinedNonMaxSuppressionOp : public OpKernel {
 public:
  explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [batch_size, num_anchors, q, 4]
    const Tensor& boxes = context->input(0);
    // scores: [batch_size, num_anchors, num_classes]
    const Tensor& scores = context->input(1);
    OP_REQUIRES(
        context, (boxes.dim_size(0) == scores.dim_size(0)),
        errors::InvalidArgument("boxes and scores must have same batch size"));

    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_size_per_class must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    const int max_size_per_class = max_output_size.scalar<int>()();
    // max_total_size: scalar
    const Tensor& max_total_size = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_total_size.shape()),
        errors::InvalidArgument("max_total_size must be 0-D, got shape ",
                                max_total_size.shape().DebugString()));
    const int max_total_size_per_batch = max_total_size.scalar<int>()();
    OP_REQUIRES(context, max_total_size_per_batch > 0,
                errors::InvalidArgument("max_total_size must be > 0"));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(4);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const float iou_threshold_val = iou_threshold.scalar<float>()();

    // score_threshold: scalar
    const Tensor& score_threshold = context->input(5);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    const int num_classes = scores.dim_size(2);
    ParseAndCheckCombinedNMSBoxSizes(context, boxes, &num_boxes, num_classes);
    CheckCombinedNMSScoreSizes(context, num_boxes, scores);

    if (!context->status().ok()) {
      return;
    }
    BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
                               max_size_per_class, max_total_size_per_batch,
                               score_threshold_val, iou_threshold_val,
                               pad_per_class_);
  }

 private:
  bool pad_per_class_;
};

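// CPU kernel registrations. V2-V4 are registered for both float and
// Eigen::half; the remaining ops operate on float boxes and scores.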
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
                        NonMaxSuppressionOp<CPUDevice>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV2Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV3Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV4Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
    NonMaxSuppressionWithOverlapsOp<CPUDevice>);

REGISTER_KERNEL_BUILDER(Name("CombinedNonMaxSuppression").Device(DEVICE_CPU),
                        CombinedNonMaxSuppressionOp<CPUDevice>);

}  // namespace tensorflow