1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/types/span.h"
17 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
18 #include "tensorflow/compiler/tf2xla/lib/util.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/loops.h"
26 #include "tensorflow/compiler/xla/client/lib/sorting.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 
33 namespace tensorflow {
34 namespace {
35 
36 // Converts 'input' from RGB format to HSV format.
37 // 'shape' is the shape of the red/green/blue tensors.
RGBToHSV(XlaOpKernelContext * ctx,xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & rgb,DataType dtype,const TensorShape & shape)38 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
39                                    const std::array<xla::XlaOp, 3>& rgb,
40                                    DataType dtype, const TensorShape& shape) {
41   auto zero = XlaHelpers::Zero(b, dtype);
42   auto one = XlaHelpers::One(b, dtype);
43 
44   auto red = rgb[0];
45   auto green = rgb[1];
46   auto blue = rgb[2];
47   auto value = xla::Max(xla::Max(red, green), blue);
48   auto minimum = xla::Min(xla::Min(red, green), blue);
49   auto range = xla::Sub(value, minimum);
50 
51   auto zeros = xla::Broadcast(zero, shape.dim_sizes());
52   auto saturation =
53       xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
54 
55   auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
56 
57   auto hue =
58       xla::Select(xla::Eq(green, value),
59                   xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
60                            XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
61                   xla::Add(xla::Mul(norm, xla::Sub(red, green)),
62                            XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
63   hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
64                     hue);
65   hue = xla::Select(xla::Gt(range, zero), hue, zeros);
66   hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
67   return {hue, saturation, value};
68 }
69 
70 // Converts 'input' from HSV format to RGB format.
HSVToRGB(xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & hsv,DataType dtype)71 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
72                                    const std::array<xla::XlaOp, 3>& hsv,
73                                    DataType dtype) {
74   xla::XlaOp hue = hsv[0];
75   xla::XlaOp saturation = hsv[1];
76   xla::XlaOp value = hsv[2];
77   auto zero = XlaHelpers::Zero(b, dtype);
78   auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
79   auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
80   auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
81   auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
82   auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
83 
84   auto dh = xla::Mul(hue, six);
85   auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
86   auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
87   auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
88   auto one_minus_s = xla::Sub(one, saturation);
89 
90   auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
91   auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
92   auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
93   return {red, green, blue};
94 }
95 
96 class RGBToHSVOp : public XlaOpKernel {
97  public:
RGBToHSVOp(OpKernelConstruction * context)98   explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
99 
Compile(XlaOpKernelContext * context)100   void Compile(XlaOpKernelContext* context) override {
101     const TensorShape input_shape = context->InputShape(0);
102     OP_REQUIRES(context, input_shape.dims() >= 1,
103                 errors::InvalidArgument("input must be at least 1D",
104                                         input_shape.DebugString()));
105     int channel_dim = input_shape.dims() - 1;
106     int64 channels = input_shape.dim_size(channel_dim);
107     OP_REQUIRES(
108         context, channels == 3,
109         errors::FailedPrecondition("input must have 3 channels but input has ",
110                                    channels, " channels."));
111 
112     xla::XlaBuilder* b = context->builder();
113     xla::XlaOp input = context->Input(0);
114 
115     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
116                                      /*limit_index=*/1, /*stride=*/1,
117                                      /*dimno=*/channel_dim);
118     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
119                                        /*limit_index=*/2, /*stride=*/1,
120                                        /*dimno=*/channel_dim);
121     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
122                                       /*limit_index=*/3, /*stride=*/1,
123                                       /*dimno=*/channel_dim);
124     TensorShape channel_shape = input_shape;
125     channel_shape.set_dim(channel_dim, 1);
126     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
127                         channel_shape);
128 
129     context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
130   }
131 };
132 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
133 
134 class HSVToRGBOp : public XlaOpKernel {
135  public:
HSVToRGBOp(OpKernelConstruction * context)136   explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
137 
Compile(XlaOpKernelContext * context)138   void Compile(XlaOpKernelContext* context) override {
139     const TensorShape input_shape = context->InputShape(0);
140     OP_REQUIRES(context, input_shape.dims() >= 1,
141                 errors::InvalidArgument("input must be at least 1D",
142                                         input_shape.DebugString()));
143     int channel_dim = input_shape.dims() - 1;
144     int64 channels = input_shape.dim_size(channel_dim);
145     OP_REQUIRES(
146         context, channels == 3,
147         errors::FailedPrecondition("input must have 3 channels but input has ",
148                                    channels, " channels."));
149 
150     xla::XlaBuilder* b = context->builder();
151     xla::XlaOp input = context->Input(0);
152     xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
153                                      /*limit_index=*/1, /*stride=*/1,
154                                      /*dimno=*/channel_dim);
155     xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
156                                             /*limit_index=*/2, /*stride=*/1,
157                                             /*dimno=*/channel_dim);
158     xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
159                                        /*limit_index=*/3, /*stride=*/1,
160                                        /*dimno=*/channel_dim);
161 
162     auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
163                         context->input_type(0));
164 
165     context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
166   }
167 };
168 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
169 
170 class AdjustContrastOpV2 : public XlaOpKernel {
171  public:
AdjustContrastOpV2(OpKernelConstruction * context)172   explicit AdjustContrastOpV2(OpKernelConstruction* context)
173       : XlaOpKernel(context) {}
174 
Compile(XlaOpKernelContext * context)175   void Compile(XlaOpKernelContext* context) override {
176     const TensorShape& input_shape = context->InputShape(0);
177     const TensorShape& factor_shape = context->InputShape(1);
178     OP_REQUIRES(context, input_shape.dims() >= 3,
179                 errors::InvalidArgument("input must be at least 3-D, got shape",
180                                         input_shape.DebugString()));
181     int height_dim = input_shape.dims() - 3;
182     int width_dim = input_shape.dims() - 2;
183     int channel_dim = input_shape.dims() - 1;
184     const int64 height = input_shape.dim_size(height_dim);
185     const int64 width = input_shape.dim_size(width_dim);
186 
187     OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
188                 errors::InvalidArgument("contrast_factor must be scalar: ",
189                                         factor_shape.DebugString()));
190 
191     xla::XlaBuilder* b = context->builder();
192     DataType type = context->input_type(0);
193 
194     xla::XlaOp input = context->Input(0);
195     xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);
196 
197     const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
198     auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
199     auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
200                               *context->GetOrCreateAdd(accumulation_type),
201                               {height_dim, width_dim});
202 
203     auto output = xla::Div(
204         reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
205     output = XlaHelpers::ConvertElementType(output, type);
206 
207     std::vector<int64> broadcast_dims(input_shape.dims() - 2);
208     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
209     broadcast_dims.back() = channel_dim;
210     output =
211         xla::Add(xla::Mul(input, factor),
212                  xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
213                  broadcast_dims);
214     context->SetOutput(0, output);
215   }
216 };
217 REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
218 
219 class AdjustSaturationOp : public XlaOpKernel {
220  public:
AdjustSaturationOp(OpKernelConstruction * context)221   explicit AdjustSaturationOp(OpKernelConstruction* context)
222       : XlaOpKernel(context) {}
223 
Compile(XlaOpKernelContext * context)224   void Compile(XlaOpKernelContext* context) override {
225     const TensorShape& input_shape = context->InputShape(0);
226     const TensorShape& scale_shape = context->InputShape(1);
227     OP_REQUIRES(context, input_shape.dims() >= 3,
228                 errors::InvalidArgument("input must be at least 3-D, got shape",
229                                         input_shape.DebugString()));
230     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
231                 errors::InvalidArgument("scale must be scalar: ",
232                                         scale_shape.DebugString()));
233     const int channel_dim = input_shape.dims() - 1;
234     const int64 channels = input_shape.dim_size(channel_dim);
235     OP_REQUIRES(
236         context, channels == 3,
237         errors::InvalidArgument("input must have 3 channels but instead has ",
238                                 channels, " channels."));
239 
240     xla::XlaBuilder* b = context->builder();
241     xla::XlaOp input =
242         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
243     xla::XlaOp scale =
244         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
245 
246     DataType type = context->input_type(0);
247 
248     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
249                                      /*limit_index=*/1, /*stride=*/1,
250                                      /*dimno=*/channel_dim);
251     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
252                                        /*limit_index=*/2, /*stride=*/1,
253                                        /*dimno=*/channel_dim);
254     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
255                                       /*limit_index=*/3, /*stride=*/1,
256                                       /*dimno=*/channel_dim);
257     TensorShape channel_shape = input_shape;
258     channel_shape.set_dim(channel_dim, 1);
259     auto hsv =
260         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
261 
262     hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
263                         XlaHelpers::One(b, DT_FLOAT));
264 
265     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
266 
267     auto output = XlaHelpers::ConvertElementType(
268         xla::ConcatInDim(b, rgb, channel_dim), type);
269     context->SetOutput(0, output);
270   }
271 };
272 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
273 
274 class AdjustHueOp : public XlaOpKernel {
275  public:
AdjustHueOp(OpKernelConstruction * context)276   explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
277 
Compile(XlaOpKernelContext * context)278   void Compile(XlaOpKernelContext* context) override {
279     const TensorShape& input_shape = context->InputShape(0);
280     const TensorShape& delta_shape = context->InputShape(1);
281     OP_REQUIRES(context, input_shape.dims() >= 3,
282                 errors::InvalidArgument("input must be at least 3-D, got shape",
283                                         input_shape.DebugString()));
284     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
285                 errors::InvalidArgument("delta must be scalar: ",
286                                         delta_shape.DebugString()));
287     const int channel_dim = input_shape.dims() - 1;
288     const int64 channels = input_shape.dim_size(channel_dim);
289     OP_REQUIRES(
290         context, channels == 3,
291         errors::InvalidArgument("input must have 3 channels but instead has ",
292                                 channels, " channels."));
293 
294     xla::XlaBuilder* b = context->builder();
295     xla::XlaOp input =
296         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
297     xla::XlaOp delta =
298         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
299 
300     DataType type = context->input_type(0);
301 
302     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
303                                      /*limit_index=*/1, /*stride=*/1,
304                                      /*dimno=*/channel_dim);
305     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
306                                        /*limit_index=*/2, /*stride=*/1,
307                                        /*dimno=*/channel_dim);
308     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
309                                       /*limit_index=*/3, /*stride=*/1,
310                                       /*dimno=*/channel_dim);
311     TensorShape channel_shape = input_shape;
312     channel_shape.set_dim(channel_dim, 1);
313     auto hsv =
314         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
315 
316     auto zero = XlaHelpers::Zero(b, DT_FLOAT);
317     auto one = XlaHelpers::One(b, DT_FLOAT);
318 
319     auto& hue = hsv[0];
320     hue = xla::Rem(xla::Add(hsv[0], delta), one);
321     hue =
322         xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
323 
324     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
325 
326     auto output = XlaHelpers::ConvertElementType(
327         xla::ConcatInDim(b, rgb, channel_dim), type);
328     context->SetOutput(0, output);
329   }
330 };
331 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
332 
333 struct WhileCondFn {
334   const int64 num_boxes;
335   const int64 output_size;
336 
WhileCondFntensorflow::__anonca00298a0111::WhileCondFn337   explicit WhileCondFn(int64 num_boxes, int64 output_size)
338       : num_boxes(num_boxes), output_size(output_size) {}
339 
operator ()tensorflow::__anonca00298a0111::WhileCondFn340   xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
341                                        xla::XlaBuilder* cond_builder) const {
342     xla::XlaOp row_idx = values[0];
343     xla::XlaOp row_in_bounds =
344         xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
345     xla::XlaOp num_outputs_so_far = values[1];
346     xla::XlaOp results_not_full = xla::Lt(
347         num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
348     return xla::And(row_in_bounds, results_not_full);
349   }
350 };
351 
352 // Process the boxes one-by-one using the iou matrix mask.
353 // This implementation uses a correct, but greedy, sequential algorithm
354 // to ensure that suppressed boxes cannot themselves suppress other
355 // boxes.
356 struct SuppressBodyFn {
357   const int64 num_boxes;
358 
SuppressBodyFntensorflow::__anonca00298a0111::SuppressBodyFn359   explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {}
360 
operator ()tensorflow::__anonca00298a0111::SuppressBodyFn361   xla::StatusOr<std::vector<xla::XlaOp>> operator()(
362       absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
363     auto row_idx = values[0];
364     auto num_outputs_so_far = values[1];
365     auto iou_mask = values[2];
366     auto included_iou = values[3];
367     auto zero = xla::ConstantR0<int32>(builder, 0);
368     // Determine if current elem is active using a slice.
369     // TODO(b/118437727): The only reason we need an explicit vector is because
370     // some old GCCs can't deduce the right type for MakeConstSpan, and
371     // providing a single-value initializer list directly uses the wrong
372     // overload. Delete this once the deprecated overload is gone.
373     std::vector<xla::XlaOp> row_idx_vector = {row_idx};
374     auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
375     active_elem = xla::Reshape(active_elem, {});
376     // Increment output count iff current elem is not suppressed.
377     num_outputs_so_far = xla::Select(
378         active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
379         num_outputs_so_far);
380     // Slice out the row_idx.
381     auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
382     // Remove the diagonal from consideration. An elem cannot suppress
383     // itself.
384     row_iou = xla::DynamicUpdateSlice(
385         row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
386         {zero, row_idx});
387     // Create a suppression by inverting polarity.
388     row_iou = xla::Reshape(row_iou, {num_boxes});
389     auto supp_mask = xla::Not(row_iou);
390     // Update mask iff current elem is not suppressed.
391     included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
392                                xla::And(included_iou, supp_mask), included_iou);
393     row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
394     return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
395                                    included_iou};
396   }
397 };
398 
399 class NonMaxSuppressionOp : public XlaOpKernel {
400  public:
NonMaxSuppressionOp(OpKernelConstruction * context)401   explicit NonMaxSuppressionOp(OpKernelConstruction* context)
402       : XlaOpKernel(context) {
403     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
404                                              &pad_to_max_output_size_));
405   }
406 
Compile(XlaOpKernelContext * context)407   void Compile(XlaOpKernelContext* context) override {
408     // TODO(b/111646731): Improve scalability of this op, using blocking.
409     const TensorShape& boxes_shape = context->InputShape("boxes");
410     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
411                 errors::InvalidArgument("boxes must be 2-D, currently: ",
412                                         boxes_shape.DebugString()));
413     const int64 num_boxes = boxes_shape.dim_size(0);
414     OP_REQUIRES(context, boxes_shape.dim_size(1) == 4,
415                 errors::InvalidArgument("boxes must have 4 columns",
416                                         boxes_shape.DebugString()));
417     const TensorShape& scores_shape = context->InputShape("scores");
418     OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
419                 errors::InvalidArgument("scores must be 1-D, currently: ",
420                                         scores_shape.DebugString()));
421     OP_REQUIRES(
422         context, scores_shape.dim_size(0) == num_boxes,
423         errors::InvalidArgument("scores size must equal number of boxes",
424                                 scores_shape.DebugString()));
425     OP_REQUIRES(context, pad_to_max_output_size_,
426                 errors::Unimplemented(
427                     "XLA compilation requires pad_to_max_output_size == True"));
428     OP_REQUIRES(context, num_boxes <= kint32max,
429                 errors::InvalidArgument("XLA compilation requires number of "
430                                         "boxes to be <= kint32max, got ",
431                                         num_boxes));
432     xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes");
433     xla::PrimitiveType scores_xla_type = context->InputXlaType("scores");
434     const xla::XlaOp boxes_input = context->Input("boxes");
435     const xla::XlaOp scores_input = context->Input("scores");
436     int64 output_size;
437     OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
438     OP_REQUIRES(
439         context, output_size >= 0,
440         errors::InvalidArgument("Need output_size >= 0, got ", output_size));
441     OP_REQUIRES(context, output_size <= kint32max,
442                 errors::InvalidArgument("Need output_size <= kint32Max, got ",
443                                         output_size));
444     const xla::XlaOp score_thresh = context->Input("score_threshold");
445     const xla::XlaOp iou_thresh = context->Input("iou_threshold");
446     xla::XlaBuilder* const builder = context->builder();
447 
448     // Choose a more convenient layout.
449     const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
450     const xla::XlaOp boxes_sorted = xla::GetTupleElement(
451         xla::Sort({xla::Broadcast(scores_input, {4}), boxes},
452                   xla::CreateScalarGtComputation(
453                       {scores_xla_type, boxes_xla_type}, builder),
454                   /*dimension=*/1),
455         1);
456     // Track the mapping of indices into sorted domain.
457     const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
458     const xla::XlaOp indices_sort = xla::Sort(
459         {scores_input, iota_indices},
460         xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder));
461     const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
462     const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0);
463 
464     // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0.
465     const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
466                                                          /*start_index=*/0,
467                                                          /*limit_index=*/1,
468                                                          /*stride=*/1,
469                                                          /*dimno=*/0),
470                                          {num_boxes});
471     const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
472                                                          /*start_index=*/1,
473                                                          /*limit_index=*/2,
474                                                          /*stride=*/1,
475                                                          /*dimno=*/0),
476                                          {num_boxes});
477     const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
478                                                          /*start_index=*/2,
479                                                          /*limit_index=*/3,
480                                                          /*stride=*/1,
481                                                          /*dimno=*/0),
482                                          {num_boxes});
483     const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
484                                                          /*start_index=*/3,
485                                                          /*limit_index=*/4,
486                                                          /*stride=*/1,
487                                                          /*dimno=*/0),
488                                          {num_boxes});
489 
490     xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
491     xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
492     xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
493     xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
494     xla::XlaOp area = (y2 - y1) * (x2 - x1);
495 
496     // Shapes are henceforth [1, num_boxes].
497     y1 = xla::Broadcast(y1, {1});
498     y2 = xla::Broadcast(y2, {1});
499     x1 = xla::Broadcast(x1, {1});
500     x2 = xla::Broadcast(x2, {1});
501     area = xla::Broadcast(area, {1});
502 
503     // Shapes are henceforth [num_boxes, num_boxes].
504     xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
505     xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
506     xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
507     xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
508     auto square_zero = xla::ZerosLike(i_xmin);
509 
510     xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
511                         xla::Max(i_ymax - i_ymin, square_zero);
512     xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
513     xla::XlaOp iou = i_area / u_area;
514 
515     xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
516     xla::XlaOp included_iou =
517         xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});
518 
519     std::vector<xla::XlaOp> init_values;
520     init_values.reserve(4);
521     init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // col_idx
522     init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
523     init_values.push_back(iou_thresh_mask);
524     init_values.push_back(included_iou);
525 
526     auto suppress_loop_result =
527         xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
528                              SuppressBodyFn(num_boxes), init_values,
529                              "suppress_loop", builder)
530             .ValueOrDie();
531 
532     xla::XlaOp included_score =
533         xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
534     xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);
535 
536     // Only consider boxes over which we have iterated. This allows for accurate
537     // counting. DynamicSlice would require knowledge of the size of the output.
538     auto valid_elem = xla::Lt(
539         iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
540     included = xla::And(included, valid_elem);
541 
542     xla::XlaOp neg_inf =
543         xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
544     xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
545     xla::XlaOp output_tuple = TopK(scores_included, output_size);
546     xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
547     // Calculate num_valid.
548     // Note: num_valid cannot be taken from the loop outputs, because outputs
549     // can be suppressed by score threshold.
550     xla::XlaOp ones_included = xla::Select(
551         included,
552         xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
553         xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
554     // num_valid is scalar. Value should be bound by output_size.
555     xla::XlaOp num_valid_total = xla::Reduce(
556         ones_included,
557         /*init_value=*/xla::ConstantR0<int>(builder, 0),
558         /*computation=*/CreateScalarAddComputation(xla::S32, builder),
559         /*dimensions_to_reduce=*/{0});
560     xla::XlaOp num_valid =
561         xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
562 
563     // Re-index into the original scores input tensor, using a Gather.
564     // Boxes were suppressed in the sorted domain.
565     xla::XlaOp selected_indices;
566     DataType gather_type = context->expected_output_dtype(0);
567     OP_REQUIRES_OK(
568         context,
569         XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
570                   TensorShape({output_size}),
571                   /*axis=*/0,
572                   /*indices_are_nd=*/false,
573                   /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));
574 
575     context->SetOutput(0, selected_indices);
576     context->SetOutput(1, num_valid);
577   }
578 
579  private:
580   bool pad_to_max_output_size_;
581 };
582 
583 REGISTER_XLA_OP(
584     Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
585     NonMaxSuppressionOp);
586 
587 }  // namespace
588 }  // namespace tensorflow
589