1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/non_max_suppression_op.h"
21 
22 #include <functional>
23 #include <queue>
24 #include <vector>
25 
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/bounds_check.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/gtl/stl_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace tensorflow {
38 namespace {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 
CheckScoreSizes(OpKernelContext * context,int num_boxes,const Tensor & scores)42 static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
43                                    const Tensor& scores) {
44   // The shape of 'scores' is [num_boxes]
45   OP_REQUIRES(context, scores.dims() == 1,
46               errors::InvalidArgument("scores must be 1-D",
47                                       scores.shape().DebugString()));
48   OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
49               errors::InvalidArgument("scores has incompatible shape"));
50 }
51 
ParseAndCheckOverlapSizes(OpKernelContext * context,const Tensor & overlaps,int * num_boxes)52 static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
53                                              const Tensor& overlaps,
54                                              int* num_boxes) {
55   // the shape of 'overlaps' is [num_boxes, num_boxes]
56   OP_REQUIRES(context, overlaps.dims() == 2,
57               errors::InvalidArgument("overlaps must be 2-D",
58                                       overlaps.shape().DebugString()));
59 
60   *num_boxes = overlaps.dim_size(0);
61   OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
62               errors::InvalidArgument("overlaps must be square",
63                                       overlaps.shape().DebugString()));
64 }
65 
ParseAndCheckBoxSizes(OpKernelContext * context,const Tensor & boxes,int * num_boxes)66 static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
67                                          const Tensor& boxes, int* num_boxes) {
68   // The shape of 'boxes' is [num_boxes, 4]
69   OP_REQUIRES(context, boxes.dims() == 2,
70               errors::InvalidArgument("boxes must be 2-D",
71                                       boxes.shape().DebugString()));
72   *num_boxes = boxes.dim_size(0);
73   OP_REQUIRES(context, boxes.dim_size(1) == 4,
74               errors::InvalidArgument("boxes must have 4 columns"));
75 }
76 
CheckCombinedNMSScoreSizes(OpKernelContext * context,int num_boxes,const Tensor & scores)77 static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
78                                               int num_boxes,
79                                               const Tensor& scores) {
80   // The shape of 'scores' is [batch_size, num_boxes, num_classes]
81   OP_REQUIRES(context, scores.dims() == 3,
82               errors::InvalidArgument("scores must be 3-D",
83                                       scores.shape().DebugString()));
84   OP_REQUIRES(context, scores.dim_size(1) == num_boxes,
85               errors::InvalidArgument("scores has incompatible shape"));
86 }
87 
ParseAndCheckCombinedNMSBoxSizes(OpKernelContext * context,const Tensor & boxes,int * num_boxes,const int num_classes)88 static inline void ParseAndCheckCombinedNMSBoxSizes(OpKernelContext* context,
89                                                     const Tensor& boxes,
90                                                     int* num_boxes,
91                                                     const int num_classes) {
92   // The shape of 'boxes' is [batch_size, num_boxes, q, 4]
93   OP_REQUIRES(context, boxes.dims() == 4,
94               errors::InvalidArgument("boxes must be 4-D",
95                                       boxes.shape().DebugString()));
96 
97   bool box_check = boxes.dim_size(2) == 1 || boxes.dim_size(2) == num_classes;
98   OP_REQUIRES(context, box_check,
99               errors::InvalidArgument(
100                   "third dimension of boxes must be either 1 or num classes"));
101   *num_boxes = boxes.dim_size(1);
102   OP_REQUIRES(context, boxes.dim_size(3) == 4,
103               errors::InvalidArgument("boxes must have 4 columns"));
104 }
105 // Return intersection-over-union overlap between boxes i and j
106 template <typename T>
IOUGreaterThanThreshold(typename TTypes<T,2>::ConstTensor boxes,int i,int j,T iou_threshold)107 static inline bool IOUGreaterThanThreshold(
108     typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
109   const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
110   const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
111   const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
112   const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
113   const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
114   const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
115   const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
116   const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
117   const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
118   const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
119   if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return 0;
120   const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
121   const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
122   const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
123   const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
124   const T intersection_area =
125       std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
126       std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
127   const T iou = intersection_area / (area_i + area_j - intersection_area);
128   return iou > iou_threshold;
129 }
130 
OverlapsGreaterThanThreshold(typename TTypes<float,2>::ConstTensor overlaps,int i,int j,float overlap_threshold)131 static inline bool OverlapsGreaterThanThreshold(
132     typename TTypes<float, 2>::ConstTensor overlaps, int i, int j,
133     float overlap_threshold) {
134   return overlaps(i, j) > overlap_threshold;
135 }
136 
137 template <typename T>
CreateIOUSuppressCheckFn(const Tensor & boxes,float threshold)138 static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
139     const Tensor& boxes, float threshold) {
140   typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
141   return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
142                    std::placeholders::_1, std::placeholders::_2,
143                    static_cast<T>(threshold));
144 }
145 
CreateOverlapsSuppressCheckFn(const Tensor & overlaps,float threshold)146 static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
147     const Tensor& overlaps, float threshold) {
148   typename TTypes<float, 2>::ConstTensor overlaps_data =
149       overlaps.tensor<float, 2>();
150   return std::bind(&OverlapsGreaterThanThreshold, overlaps_data,
151                    std::placeholders::_1, std::placeholders::_2, threshold);
152 }
153 
154 template <typename T>
DoNonMaxSuppressionOp(OpKernelContext * context,const Tensor & scores,int num_boxes,const Tensor & max_output_size,const float score_threshold,const std::function<bool (int,int)> & suppress_check_fn,bool pad_to_max_output_size=false,int * ptr_num_valid_outputs=nullptr)155 void DoNonMaxSuppressionOp(
156     OpKernelContext* context, const Tensor& scores, int num_boxes,
157     const Tensor& max_output_size, const float score_threshold,
158     const std::function<bool(int, int)>& suppress_check_fn,
159     bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
160   const int output_size = max_output_size.scalar<int>()();
161 
162   std::vector<T> scores_data(num_boxes);
163   std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
164 
165   // Data structure for selection candidate in NMS.
166   struct Candidate {
167     int box_index;
168     T score;
169   };
170 
171   auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
172     return bs_i.score < bs_j.score;
173   };
174   std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
175       candidate_priority_queue(cmp);
176   for (int i = 0; i < scores_data.size(); ++i) {
177     if (static_cast<float>(scores_data[i]) > score_threshold) {
178       candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
179     }
180   }
181 
182   std::vector<int> selected;
183   std::vector<T> selected_scores;
184   Candidate next_candidate;
185 
186   while (selected.size() < output_size && !candidate_priority_queue.empty()) {
187     next_candidate = candidate_priority_queue.top();
188     candidate_priority_queue.pop();
189 
190     // Overlapping boxes are likely to have similar scores,
191     // therefore we iterate through the previously selected boxes backwards
192     // in order to see if `next_candidate` should be suppressed.
193     bool should_select = true;
194 
195     for (int j = static_cast<int>(selected.size()) - 1; j >= 0; --j) {
196       if (suppress_check_fn(next_candidate.box_index, selected[j])) {
197         should_select = false;
198         break;
199       }
200     }
201 
202     if (should_select) {
203       selected.push_back(next_candidate.box_index);
204       selected_scores.push_back(next_candidate.score);
205     }
206   }
207 
208   int num_valid_outputs = selected.size();
209   if (pad_to_max_output_size) {
210     selected.resize(output_size, 0);
211     selected_scores.resize(output_size, static_cast<T>(0));
212   }
213   if (ptr_num_valid_outputs) {
214     *ptr_num_valid_outputs = num_valid_outputs;
215   }
216 
217   // Allocate output tensors
218   Tensor* output_indices = nullptr;
219   TensorShape output_shape({static_cast<int>(selected.size())});
220   OP_REQUIRES_OK(context,
221                  context->allocate_output(0, output_shape, &output_indices));
222   TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
223   std::copy_n(selected.begin(), selected.size(), output_indices_data.data());
224 }
225 
BatchedNonMaxSuppressionOp(OpKernelContext * context,const Tensor & inp_boxes,const Tensor & inp_scores,int num_boxes,const int max_size_per_class,const int total_size_per_batch,const float score_threshold,const float iou_threshold,bool pad_per_class=false)226 void BatchedNonMaxSuppressionOp(
227     OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
228     int num_boxes, const int max_size_per_class, const int total_size_per_batch,
229     const float score_threshold, const float iou_threshold,
230     bool pad_per_class = false) {
231   int q = inp_boxes.dim_size(2);
232   int num_classes = inp_scores.dim_size(2);
233   const int num_batches = inp_boxes.dim_size(0);
234 
235   // Default clip window of [0, 0, 1, 1] if none specified
236   std::vector<float> clip_window{0, 0, 1, 1};
237 
238   // [num_batches, per_batch_size * 4]
239   std::vector<std::vector<float>> nmsed_boxes(num_batches);
240   // [num_batches, per_batch_size]
241   std::vector<std::vector<float>> nmsed_scores(num_batches);
242   // [num_batches, per_batch_size]
243   std::vector<std::vector<float>> nmsed_classes(num_batches);
244   // [num_batches]
245   std::vector<int> final_valid_detections;
246 
247   int per_batch_size = total_size_per_batch;
248 
249   // perform non_max_suppression operation for each batch independently
250   for (int batch = 0; batch < num_batches; ++batch) {
251     // dims of per_batch_boxes [num_boxes, q, 4]
252     Tensor per_batch_boxes = inp_boxes.Slice(batch, batch + 1);
253     // dims of per_batch_scores [num_boxes, num_classes]
254     Tensor per_batch_scores = inp_scores.Slice(batch, batch + 1);
255 
256     struct ResultCandidate {
257       int box_index;
258       float score;
259       int class_idx;
260       float box_coord[4];
261     };
262 
263     std::vector<ResultCandidate> result_candidate_vec;
264 
265     float* scores_data = per_batch_scores.unaligned_flat<float>().data();
266     float* boxes_data = per_batch_boxes.unaligned_flat<float>().data();
267 
268     // Iterate through all classes
269     for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
270       std::vector<float> class_scores_data;
271       class_scores_data.reserve(num_boxes);
272       std::vector<float> class_boxes_data;
273       class_boxes_data.reserve(num_boxes * 4);
274 
275       for (int box = 0; box < num_boxes; ++box) {
276         // Get the scores per class
277         // class_scores_data dim is [num_boxes].
278         class_scores_data.push_back(scores_data[box * num_classes + class_idx]);
279         for (int cid = 0; cid < 4; ++cid) {
280           if (q > 1) {
281             // Get the boxes per class. class_boxes_data dims is [num_boxes, 4]
282             class_boxes_data.push_back(
283                 boxes_data[(box * q + class_idx) * 4 + cid]);
284           } else {
285             class_boxes_data.push_back(boxes_data[box * 4 + cid]);
286           }
287         }
288       }
289 
290       // Copy class_boxes_data to a tensor
291       TensorShape boxesShape({num_boxes, 4});
292       Tensor boxes(per_batch_boxes.dtype(), boxesShape);
293       std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
294                   boxes.unaligned_flat<float>().data());
295 
296       const int size_per_class = std::min(max_size_per_class, num_boxes);
297       // Do NMS, get the candidate indices of form vector<int>
298       // Data structure for selection candidate in NMS.
299       struct Candidate {
300         int box_index;
301         float score;
302       };
303       auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
304         return bs_i.score > bs_j.score;
305       };
306       std::vector<Candidate> candidate_vector;
307       for (int i = 0; i < class_scores_data.size(); ++i) {
308         if (class_scores_data[i] > score_threshold) {
309           candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
310         }
311       }
312 
313       std::vector<int> selected;
314       std::vector<float> selected_boxes;
315       Candidate next_candidate;
316 
317       std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
318       const Tensor const_boxes = boxes;
319       typename TTypes<float, 2>::ConstTensor boxes_data =
320           const_boxes.tensor<float, 2>();
321       int candidate_idx = 0;
322       while (selected.size() < size_per_class &&
323              candidate_idx < candidate_vector.size()) {
324         next_candidate = candidate_vector[candidate_idx++];
325 
326         // Overlapping boxes are likely to have similar scores,
327         // therefore we iterate through the previously selected boxes backwards
328         // in order to see if `next_candidate` should be suppressed.
329         bool should_select = true;
330         for (int j = selected.size() - 1; j >= 0; --j) {
331           if (IOUGreaterThanThreshold(boxes_data, next_candidate.box_index,
332                                       selected[j], iou_threshold)) {
333             should_select = false;
334             break;
335           }
336         }
337 
338         if (should_select) {
339           selected.push_back(next_candidate.box_index);
340           // Add the selected box to the result candidate. Sorted by score
341           int id = next_candidate.box_index;
342           ResultCandidate rc = {next_candidate.box_index,
343                                 next_candidate.score,
344                                 class_idx,
345                                 {boxes_data(id, 0), boxes_data(id, 1),
346                                  boxes_data(id, 2), boxes_data(id, 3)}};
347           result_candidate_vec.push_back(rc);
348         }
349       }
350     }
351 
352     auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
353       return rc_i.score > rc_j.score;
354     };
355     std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
356 
357     int max_detections = 0;
358     // If pad_per_class is false, we always pad to max_total_size
359     if (!pad_per_class) {
360       max_detections =
361           std::min((int)result_candidate_vec.size(), total_size_per_batch);
362       per_batch_size = total_size_per_batch;
363     } else {
364       per_batch_size =
365           std::min(total_size_per_batch, max_size_per_class * num_classes);
366       max_detections =
367           std::min(per_batch_size, (int)result_candidate_vec.size());
368     }
369 
370     final_valid_detections.push_back(max_detections);
371 
372     int curr_total_size = max_detections;
373     int result_idx = 0;
374     // Pick the top max_detections values
375     while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
376       ResultCandidate next_candidate = result_candidate_vec[result_idx++];
377       // Add to final output vectors
378       nmsed_boxes[batch].push_back(
379           std::max(std::min(next_candidate.box_coord[0], clip_window[2]),
380                    clip_window[0]));
381       nmsed_boxes[batch].push_back(
382           std::max(std::min(next_candidate.box_coord[1], clip_window[3]),
383                    clip_window[1]));
384       nmsed_boxes[batch].push_back(
385           std::max(std::min(next_candidate.box_coord[2], clip_window[2]),
386                    clip_window[0]));
387       nmsed_boxes[batch].push_back(
388           std::max(std::min(next_candidate.box_coord[3], clip_window[3]),
389                    clip_window[1]));
390       nmsed_scores[batch].push_back(next_candidate.score);
391       nmsed_classes[batch].push_back(next_candidate.class_idx);
392       curr_total_size--;
393     }
394 
395     nmsed_boxes[batch].resize(per_batch_size * 4, 0);
396     nmsed_scores[batch].resize(per_batch_size, 0);
397     nmsed_classes[batch].resize(per_batch_size, 0);
398   }
399 
400   Tensor* nmsed_boxes_t = nullptr;
401   TensorShape boxes_shape({num_batches, per_batch_size, 4});
402   OP_REQUIRES_OK(context,
403                  context->allocate_output(0, boxes_shape, &nmsed_boxes_t));
404   auto nmsed_boxes_flat = nmsed_boxes_t->template flat<float>();
405 
406   Tensor* nmsed_scores_t = nullptr;
407   TensorShape scores_shape({num_batches, per_batch_size});
408   OP_REQUIRES_OK(context,
409                  context->allocate_output(1, scores_shape, &nmsed_scores_t));
410   auto nmsed_scores_flat = nmsed_scores_t->template flat<float>();
411 
412   Tensor* nmsed_classes_t = nullptr;
413   OP_REQUIRES_OK(context,
414                  context->allocate_output(2, scores_shape, &nmsed_classes_t));
415   auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();
416 
417   Tensor* valid_detections_t = nullptr;
418   TensorShape valid_detections_shape({num_batches});
419   OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
420                                                    &valid_detections_t));
421   auto valid_detections_flat = valid_detections_t->template flat<int>();
422 
423   for (int i = 0; i < num_batches; ++i) {
424     valid_detections_flat(i) = final_valid_detections[i];
425     for (int j = 0; j < per_batch_size; ++j) {
426       nmsed_scores_flat(i * per_batch_size + j) = nmsed_scores[i][j];
427       nmsed_classes_flat(i * per_batch_size + j) = nmsed_classes[i][j];
428       for (int k = 0; k < 4; ++k) {
429         nmsed_boxes_flat(i * per_batch_size * 4 + j * 4 + k) =
430             nmsed_boxes[i][j * 4 + k];
431       }
432     }
433   }
434 }
435 
436 }  // namespace
437 
438 template <typename Device>
439 class NonMaxSuppressionOp : public OpKernel {
440  public:
NonMaxSuppressionOp(OpKernelConstruction * context)441   explicit NonMaxSuppressionOp(OpKernelConstruction* context)
442       : OpKernel(context) {
443     OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_));
444   }
445 
Compute(OpKernelContext * context)446   void Compute(OpKernelContext* context) override {
447     // boxes: [num_boxes, 4]
448     const Tensor& boxes = context->input(0);
449     // scores: [num_boxes]
450     const Tensor& scores = context->input(1);
451     // max_output_size: scalar
452     const Tensor& max_output_size = context->input(2);
453     OP_REQUIRES(
454         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
455         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
456                                 max_output_size.shape().DebugString()));
457 
458     OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
459                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
460     int num_boxes = 0;
461     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
462     CheckScoreSizes(context, num_boxes, scores);
463     if (!context->status().ok()) {
464       return;
465     }
466     auto suppress_check_fn =
467         CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
468 
469     const float score_threshold_val = std::numeric_limits<float>::lowest();
470     DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
471                                  score_threshold_val, suppress_check_fn);
472   }
473 
474  private:
475   float iou_threshold_;
476 };
477 
478 template <typename Device, typename T>
479 class NonMaxSuppressionV2Op : public OpKernel {
480  public:
NonMaxSuppressionV2Op(OpKernelConstruction * context)481   explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
482       : OpKernel(context) {}
483 
Compute(OpKernelContext * context)484   void Compute(OpKernelContext* context) override {
485     // boxes: [num_boxes, 4]
486     const Tensor& boxes = context->input(0);
487     // scores: [num_boxes]
488     const Tensor& scores = context->input(1);
489     // max_output_size: scalar
490     const Tensor& max_output_size = context->input(2);
491     OP_REQUIRES(
492         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
493         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
494                                 max_output_size.shape().DebugString()));
495     // iou_threshold: scalar
496     const Tensor& iou_threshold = context->input(3);
497     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
498                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
499                                         iou_threshold.shape().DebugString()));
500     const float iou_threshold_val = iou_threshold.scalar<float>()();
501 
502     OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
503                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
504     int num_boxes = 0;
505     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
506     CheckScoreSizes(context, num_boxes, scores);
507     if (!context->status().ok()) {
508       return;
509     }
510     auto suppress_check_fn =
511         CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
512 
513     const float score_threshold_val = std::numeric_limits<float>::lowest();
514     DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
515                              score_threshold_val, suppress_check_fn);
516   }
517 };
518 
519 template <typename Device, typename T>
520 class NonMaxSuppressionV3Op : public OpKernel {
521  public:
NonMaxSuppressionV3Op(OpKernelConstruction * context)522   explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
523       : OpKernel(context) {}
524 
Compute(OpKernelContext * context)525   void Compute(OpKernelContext* context) override {
526     // boxes: [num_boxes, 4]
527     const Tensor& boxes = context->input(0);
528     // scores: [num_boxes]
529     const Tensor& scores = context->input(1);
530     // max_output_size: scalar
531     const Tensor& max_output_size = context->input(2);
532     OP_REQUIRES(
533         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
534         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
535                                 max_output_size.shape().DebugString()));
536     // iou_threshold: scalar
537     const Tensor& iou_threshold = context->input(3);
538     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
539                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
540                                         iou_threshold.shape().DebugString()));
541     const float iou_threshold_val = iou_threshold.scalar<float>()();
542     OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
543                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
544     // score_threshold: scalar
545     const Tensor& score_threshold = context->input(4);
546     OP_REQUIRES(
547         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
548         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
549                                 score_threshold.shape().DebugString()));
550     const float score_threshold_val = score_threshold.scalar<float>()();
551 
552     int num_boxes = 0;
553     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
554     CheckScoreSizes(context, num_boxes, scores);
555     if (!context->status().ok()) {
556       return;
557     }
558 
559     auto suppress_check_fn =
560         CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
561 
562     DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
563                              score_threshold_val, suppress_check_fn);
564   }
565 };
566 
567 template <typename Device, typename T>
568 class NonMaxSuppressionV4Op : public OpKernel {
569  public:
NonMaxSuppressionV4Op(OpKernelConstruction * context)570   explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
571       : OpKernel(context) {
572     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
573                                              &pad_to_max_output_size_));
574   }
575 
Compute(OpKernelContext * context)576   void Compute(OpKernelContext* context) override {
577     // boxes: [num_boxes, 4]
578     const Tensor& boxes = context->input(0);
579     // scores: [num_boxes]
580     const Tensor& scores = context->input(1);
581     // max_output_size: scalar
582     const Tensor& max_output_size = context->input(2);
583     OP_REQUIRES(
584         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
585         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
586                                 max_output_size.shape().DebugString()));
587     // iou_threshold: scalar
588     const Tensor& iou_threshold = context->input(3);
589     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
590                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
591                                         iou_threshold.shape().DebugString()));
592     const float iou_threshold_val = iou_threshold.scalar<float>()();
593     OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
594                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
595     // score_threshold: scalar
596     const Tensor& score_threshold = context->input(4);
597     OP_REQUIRES(
598         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
599         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
600                                 score_threshold.shape().DebugString()));
601     const float score_threshold_val = score_threshold.scalar<float>()();
602 
603     int num_boxes = 0;
604     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
605     CheckScoreSizes(context, num_boxes, scores);
606     if (!context->status().ok()) {
607       return;
608     }
609 
610     auto suppress_check_fn =
611         CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
612     int num_valid_outputs;
613 
614     DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
615                              score_threshold_val, suppress_check_fn,
616                              pad_to_max_output_size_, &num_valid_outputs);
617 
618     // Allocate scalar output tensor for number of indices computed.
619     Tensor* num_outputs_t = nullptr;
620     OP_REQUIRES_OK(context, context->allocate_output(
621                                 1, tensorflow::TensorShape{}, &num_outputs_t));
622     num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
623   }
624 
625  private:
626   bool pad_to_max_output_size_;
627 };
628 
629 template <typename Device>
630 class NonMaxSuppressionWithOverlapsOp : public OpKernel {
631  public:
NonMaxSuppressionWithOverlapsOp(OpKernelConstruction * context)632   explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
633       : OpKernel(context) {}
634 
Compute(OpKernelContext * context)635   void Compute(OpKernelContext* context) override {
636     // overlaps: [num_boxes, num_boxes]
637     const Tensor& overlaps = context->input(0);
638     // scores: [num_boxes]
639     const Tensor& scores = context->input(1);
640     // max_output_size: scalar
641     const Tensor& max_output_size = context->input(2);
642     OP_REQUIRES(
643         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
644         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
645                                 max_output_size.shape().DebugString()));
646     // overlap_threshold: scalar
647     const Tensor& overlap_threshold = context->input(3);
648     OP_REQUIRES(
649         context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
650         errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
651                                 overlap_threshold.shape().DebugString()));
652     const float overlap_threshold_val = overlap_threshold.scalar<float>()();
653 
654     // score_threshold: scalar
655     const Tensor& score_threshold = context->input(4);
656     OP_REQUIRES(
657         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
658         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
659                                 score_threshold.shape().DebugString()));
660     const float score_threshold_val = score_threshold.scalar<float>()();
661 
662     int num_boxes = 0;
663     ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
664     CheckScoreSizes(context, num_boxes, scores);
665     if (!context->status().ok()) {
666       return;
667     }
668     auto suppress_check_fn =
669         CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
670 
671     DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
672                                  score_threshold_val, suppress_check_fn);
673   }
674 };
675 
676 template <typename Device>
677 class CombinedNonMaxSuppressionOp : public OpKernel {
678  public:
CombinedNonMaxSuppressionOp(OpKernelConstruction * context)679   explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
680       : OpKernel(context) {
681     OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
682   }
683 
Compute(OpKernelContext * context)684   void Compute(OpKernelContext* context) override {
685     // boxes: [batch_size, num_anchors, q, 4]
686     const Tensor& boxes = context->input(0);
687     // scores: [batch_size, num_anchors, num_classes]
688     const Tensor& scores = context->input(1);
689     OP_REQUIRES(
690         context, (boxes.dim_size(0) == scores.dim_size(0)),
691         errors::InvalidArgument("boxes and scores must have same batch size"));
692 
693     // max_output_size: scalar
694     const Tensor& max_output_size = context->input(2);
695     OP_REQUIRES(
696         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
697         errors::InvalidArgument("max_size_per_class must be 0-D, got shape ",
698                                 max_output_size.shape().DebugString()));
699     const int max_size_per_class = max_output_size.scalar<int>()();
700     // max_total_size: scalar
701     const Tensor& max_total_size = context->input(3);
702     OP_REQUIRES(
703         context, TensorShapeUtils::IsScalar(max_total_size.shape()),
704         errors::InvalidArgument("max_total_size must be 0-D, got shape ",
705                                 max_total_size.shape().DebugString()));
706     const int max_total_size_per_batch = max_total_size.scalar<int>()();
707     OP_REQUIRES(context, max_total_size_per_batch > 0,
708                 errors::InvalidArgument("max_total_size must be > 0"));
709     // iou_threshold: scalar
710     const Tensor& iou_threshold = context->input(4);
711     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
712                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
713                                         iou_threshold.shape().DebugString()));
714     const float iou_threshold_val = iou_threshold.scalar<float>()();
715 
716     // score_threshold: scalar
717     const Tensor& score_threshold = context->input(5);
718     OP_REQUIRES(
719         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
720         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
721                                 score_threshold.shape().DebugString()));
722     const float score_threshold_val = score_threshold.scalar<float>()();
723 
724     OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
725                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
726     int num_boxes = 0;
727     const int num_classes = scores.dim_size(2);
728     ParseAndCheckCombinedNMSBoxSizes(context, boxes, &num_boxes, num_classes);
729     CheckCombinedNMSScoreSizes(context, num_boxes, scores);
730 
731     if (!context->status().ok()) {
732       return;
733     }
734     BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
735                                max_size_per_class, max_total_size_per_batch,
736                                score_threshold_val, iou_threshold_val,
737                                pad_per_class_);
738   }
739 
740  private:
741   bool pad_per_class_;
742 };
743 
// CPU kernel registrations for the non-max-suppression op family.

// NonMaxSuppression (original version): registered without a type constraint.
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
                        NonMaxSuppressionOp<CPUDevice>);

// NonMaxSuppressionV2: one kernel per supported "T" (float and Eigen::half).
REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV2Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);

// NonMaxSuppressionV3: one kernel per supported "T" (float and Eigen::half).
REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV3Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);

// NonMaxSuppressionV4: one kernel per supported "T" (float and Eigen::half).
REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV4Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);

// NonMaxSuppressionWithOverlaps: registered without a type constraint.
REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
    NonMaxSuppressionWithOverlapsOp<CPUDevice>);

// CombinedNonMaxSuppression: registered without a type constraint.
REGISTER_KERNEL_BUILDER(Name("CombinedNonMaxSuppression").Device(DEVICE_CPU),
                        CombinedNonMaxSuppressionOp<CPUDevice>);
777 
778 }  // namespace tensorflow
779