/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "absl/types/span.h"
17 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
18 #include "tensorflow/compiler/tf2xla/lib/util.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/loops.h"
26 #include "tensorflow/compiler/xla/client/lib/sorting.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/types.pb.h"
32
33 namespace tensorflow {
34 namespace {
35
36 // Converts 'input' from RGB format to HSV format.
37 // 'shape' is the shape of the red/green/blue tensors.
RGBToHSV(XlaOpKernelContext * ctx,xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & rgb,DataType dtype,const TensorShape & shape)38 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
39 const std::array<xla::XlaOp, 3>& rgb,
40 DataType dtype, const TensorShape& shape) {
41 auto zero = XlaHelpers::Zero(b, dtype);
42 auto one = XlaHelpers::One(b, dtype);
43
44 auto red = rgb[0];
45 auto green = rgb[1];
46 auto blue = rgb[2];
47 auto value = xla::Max(xla::Max(red, green), blue);
48 auto minimum = xla::Min(xla::Min(red, green), blue);
49 auto range = xla::Sub(value, minimum);
50
51 auto zeros = xla::Broadcast(zero, shape.dim_sizes());
52 auto saturation =
53 xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
54
55 auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
56
57 auto hue =
58 xla::Select(xla::Eq(green, value),
59 xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
60 XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
61 xla::Add(xla::Mul(norm, xla::Sub(red, green)),
62 XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
63 hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
64 hue);
65 hue = xla::Select(xla::Gt(range, zero), hue, zeros);
66 hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
67 return {hue, saturation, value};
68 }
69
70 // Converts 'input' from HSV format to RGB format.
HSVToRGB(xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & hsv,DataType dtype)71 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
72 const std::array<xla::XlaOp, 3>& hsv,
73 DataType dtype) {
74 xla::XlaOp hue = hsv[0];
75 xla::XlaOp saturation = hsv[1];
76 xla::XlaOp value = hsv[2];
77 auto zero = XlaHelpers::Zero(b, dtype);
78 auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
79 auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
80 auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
81 auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
82 auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
83
84 auto dh = xla::Mul(hue, six);
85 auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
86 auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
87 auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
88 auto one_minus_s = xla::Sub(one, saturation);
89
90 auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
91 auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
92 auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
93 return {red, green, blue};
94 }
95
96 class RGBToHSVOp : public XlaOpKernel {
97 public:
RGBToHSVOp(OpKernelConstruction * context)98 explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
99
Compile(XlaOpKernelContext * context)100 void Compile(XlaOpKernelContext* context) override {
101 const TensorShape input_shape = context->InputShape(0);
102 OP_REQUIRES(context, input_shape.dims() >= 1,
103 errors::InvalidArgument("input must be at least 1D",
104 input_shape.DebugString()));
105 int channel_dim = input_shape.dims() - 1;
106 int64 channels = input_shape.dim_size(channel_dim);
107 OP_REQUIRES(
108 context, channels == 3,
109 errors::FailedPrecondition("input must have 3 channels but input has ",
110 channels, " channels."));
111
112 xla::XlaBuilder* b = context->builder();
113 xla::XlaOp input = context->Input(0);
114
115 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
116 /*limit_index=*/1, /*stride=*/1,
117 /*dimno=*/channel_dim);
118 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
119 /*limit_index=*/2, /*stride=*/1,
120 /*dimno=*/channel_dim);
121 xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
122 /*limit_index=*/3, /*stride=*/1,
123 /*dimno=*/channel_dim);
124 TensorShape channel_shape = input_shape;
125 channel_shape.set_dim(channel_dim, 1);
126 auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
127 channel_shape);
128
129 context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
130 }
131 };
132 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
133
134 class HSVToRGBOp : public XlaOpKernel {
135 public:
HSVToRGBOp(OpKernelConstruction * context)136 explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
137
Compile(XlaOpKernelContext * context)138 void Compile(XlaOpKernelContext* context) override {
139 const TensorShape input_shape = context->InputShape(0);
140 OP_REQUIRES(context, input_shape.dims() >= 1,
141 errors::InvalidArgument("input must be at least 1D",
142 input_shape.DebugString()));
143 int channel_dim = input_shape.dims() - 1;
144 int64 channels = input_shape.dim_size(channel_dim);
145 OP_REQUIRES(
146 context, channels == 3,
147 errors::FailedPrecondition("input must have 3 channels but input has ",
148 channels, " channels."));
149
150 xla::XlaBuilder* b = context->builder();
151 xla::XlaOp input = context->Input(0);
152 xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
153 /*limit_index=*/1, /*stride=*/1,
154 /*dimno=*/channel_dim);
155 xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
156 /*limit_index=*/2, /*stride=*/1,
157 /*dimno=*/channel_dim);
158 xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
159 /*limit_index=*/3, /*stride=*/1,
160 /*dimno=*/channel_dim);
161
162 auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
163 context->input_type(0));
164
165 context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
166 }
167 };
168 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
169
// Implements AdjustContrastv2: blends each pixel with the per-channel spatial
// mean of its image:
//   output = input * factor + mean * (1 - factor)
// where the mean is taken over the height and width dimensions.
class AdjustContrastOpV2 : public XlaOpKernel {
 public:
  explicit AdjustContrastOpV2(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& factor_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape",
                                        input_shape.DebugString()));
    // The last three dimensions are interpreted as [height, width, channels];
    // any leading dimensions are batch dimensions.
    int height_dim = input_shape.dims() - 3;
    int width_dim = input_shape.dims() - 2;
    int channel_dim = input_shape.dims() - 1;
    const int64 height = input_shape.dim_size(height_dim);
    const int64 width = input_shape.dim_size(width_dim);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor_shape.DebugString()));

    xla::XlaBuilder* b = context->builder();
    DataType type = context->input_type(0);

    xla::XlaOp input = context->Input(0);
    xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);

    // Sum over height and width in a wider accumulation type to reduce
    // rounding error for low-precision input types, then divide by the
    // number of spatial elements to obtain the per-channel mean.
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                              *context->GetOrCreateAdd(accumulation_type),
                              {height_dim, width_dim});

    auto output = xla::Div(
        reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
    output = XlaHelpers::ConvertElementType(output, type);

    // The reduced mean dropped the height/width dims, so it has
    // input_shape.dims() - 2 dimensions: the leading batch dims map to
    // 0..dims-4 of the input and its last dim maps to the channel dim.
    std::vector<int64> broadcast_dims(input_shape.dims() - 2);
    std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    broadcast_dims.back() = channel_dim;
    // output = input * factor + mean * (1 - factor), with the mean broadcast
    // back over height and width.
    output =
        xla::Add(xla::Mul(input, factor),
                 xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
                 broadcast_dims);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
218
219 class AdjustSaturationOp : public XlaOpKernel {
220 public:
AdjustSaturationOp(OpKernelConstruction * context)221 explicit AdjustSaturationOp(OpKernelConstruction* context)
222 : XlaOpKernel(context) {}
223
Compile(XlaOpKernelContext * context)224 void Compile(XlaOpKernelContext* context) override {
225 const TensorShape& input_shape = context->InputShape(0);
226 const TensorShape& scale_shape = context->InputShape(1);
227 OP_REQUIRES(context, input_shape.dims() >= 3,
228 errors::InvalidArgument("input must be at least 3-D, got shape",
229 input_shape.DebugString()));
230 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
231 errors::InvalidArgument("scale must be scalar: ",
232 scale_shape.DebugString()));
233 const int channel_dim = input_shape.dims() - 1;
234 const int64 channels = input_shape.dim_size(channel_dim);
235 OP_REQUIRES(
236 context, channels == 3,
237 errors::InvalidArgument("input must have 3 channels but instead has ",
238 channels, " channels."));
239
240 xla::XlaBuilder* b = context->builder();
241 xla::XlaOp input =
242 XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
243 xla::XlaOp scale =
244 XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
245
246 DataType type = context->input_type(0);
247
248 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
249 /*limit_index=*/1, /*stride=*/1,
250 /*dimno=*/channel_dim);
251 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
252 /*limit_index=*/2, /*stride=*/1,
253 /*dimno=*/channel_dim);
254 xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
255 /*limit_index=*/3, /*stride=*/1,
256 /*dimno=*/channel_dim);
257 TensorShape channel_shape = input_shape;
258 channel_shape.set_dim(channel_dim, 1);
259 auto hsv =
260 RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
261
262 hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
263 XlaHelpers::One(b, DT_FLOAT));
264
265 auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
266
267 auto output = XlaHelpers::ConvertElementType(
268 xla::ConcatInDim(b, rgb, channel_dim), type);
269 context->SetOutput(0, output);
270 }
271 };
272 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
273
274 class AdjustHueOp : public XlaOpKernel {
275 public:
AdjustHueOp(OpKernelConstruction * context)276 explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
277
Compile(XlaOpKernelContext * context)278 void Compile(XlaOpKernelContext* context) override {
279 const TensorShape& input_shape = context->InputShape(0);
280 const TensorShape& delta_shape = context->InputShape(1);
281 OP_REQUIRES(context, input_shape.dims() >= 3,
282 errors::InvalidArgument("input must be at least 3-D, got shape",
283 input_shape.DebugString()));
284 OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
285 errors::InvalidArgument("delta must be scalar: ",
286 delta_shape.DebugString()));
287 const int channel_dim = input_shape.dims() - 1;
288 const int64 channels = input_shape.dim_size(channel_dim);
289 OP_REQUIRES(
290 context, channels == 3,
291 errors::InvalidArgument("input must have 3 channels but instead has ",
292 channels, " channels."));
293
294 xla::XlaBuilder* b = context->builder();
295 xla::XlaOp input =
296 XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
297 xla::XlaOp delta =
298 XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
299
300 DataType type = context->input_type(0);
301
302 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
303 /*limit_index=*/1, /*stride=*/1,
304 /*dimno=*/channel_dim);
305 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
306 /*limit_index=*/2, /*stride=*/1,
307 /*dimno=*/channel_dim);
308 xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
309 /*limit_index=*/3, /*stride=*/1,
310 /*dimno=*/channel_dim);
311 TensorShape channel_shape = input_shape;
312 channel_shape.set_dim(channel_dim, 1);
313 auto hsv =
314 RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
315
316 auto zero = XlaHelpers::Zero(b, DT_FLOAT);
317 auto one = XlaHelpers::One(b, DT_FLOAT);
318
319 auto& hue = hsv[0];
320 hue = xla::Rem(xla::Add(hsv[0], delta), one);
321 hue =
322 xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
323
324 auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
325
326 auto output = XlaHelpers::ConvertElementType(
327 xla::ConcatInDim(b, rgb, channel_dim), type);
328 context->SetOutput(0, output);
329 }
330 };
331 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
332
333 struct WhileCondFn {
334 const int64 num_boxes;
335 const int64 output_size;
336
WhileCondFntensorflow::__anonca00298a0111::WhileCondFn337 explicit WhileCondFn(int64 num_boxes, int64 output_size)
338 : num_boxes(num_boxes), output_size(output_size) {}
339
operator ()tensorflow::__anonca00298a0111::WhileCondFn340 xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
341 xla::XlaBuilder* cond_builder) const {
342 xla::XlaOp row_idx = values[0];
343 xla::XlaOp row_in_bounds =
344 xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
345 xla::XlaOp num_outputs_so_far = values[1];
346 xla::XlaOp results_not_full = xla::Lt(
347 num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
348 return xla::And(row_in_bounds, results_not_full);
349 }
350 };
351
// Process the boxes one-by-one using the iou matrix mask.
// This implementation uses a correct, but greedy, sequential algorithm
// to ensure that suppressed boxes cannot themselves suppress other
// boxes.
struct SuppressBodyFn {
  // Total number of candidate boxes; fixes the width of the iou mask rows.
  const int64 num_boxes;

  explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {}

  // Loop body. `values` is the loop carry:
  //   values[0]: row_idx            - index of the box being processed
  //   values[1]: num_outputs_so_far - count of boxes selected so far
  //   values[2]: iou_mask           - [num_boxes, num_boxes] bool; true where
  //                                   the pairwise iou exceeds the threshold
  //   values[3]: included_iou       - [num_boxes] bool; true for boxes not
  //                                   yet suppressed
  xla::StatusOr<std::vector<xla::XlaOp>> operator()(
      absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
    auto row_idx = values[0];
    auto num_outputs_so_far = values[1];
    auto iou_mask = values[2];
    auto included_iou = values[3];
    auto zero = xla::ConstantR0<int32>(builder, 0);
    // Determine if current elem is active using a slice.
    // TODO(b/118437727): The only reason we need an explicit vector is because
    // some old GCCs can't deduce the right type for MakeConstSpan, and
    // providing a single-value initializer list directly uses the wrong
    // overload. Delete this once the deprecated overload is gone.
    std::vector<xla::XlaOp> row_idx_vector = {row_idx};
    auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
    active_elem = xla::Reshape(active_elem, {});
    // Increment output count iff current elem is not suppressed.
    num_outputs_so_far = xla::Select(
        active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
        num_outputs_so_far);
    // Slice out the row_idx.
    auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
    // Remove the diagonal from consideration. An elem cannot suppress
    // itself.
    row_iou = xla::DynamicUpdateSlice(
        row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
        {zero, row_idx});
    // Create a suppression by inverting polarity.
    row_iou = xla::Reshape(row_iou, {num_boxes});
    auto supp_mask = xla::Not(row_iou);
    // Update mask iff current elem is not suppressed. A suppressed box must
    // not be allowed to suppress others, hence the active_elem gate.
    included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
                               xla::And(included_iou, supp_mask), included_iou);
    // Advance to the next box.
    row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
    return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
                                   included_iou};
  }
};
398
// XLA implementation of NonMaxSuppressionV4. Boxes are sorted by score, a
// pairwise iou matrix is built, and a sequential while-loop greedily selects
// boxes whose iou with every previously selected box is below the threshold.
// Requires pad_to_max_output_size == true so the output shape is static.
class NonMaxSuppressionOp : public XlaOpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compile(XlaOpKernelContext* context) override {
    // TODO(b/111646731): Improve scalability of this op, using blocking.
    const TensorShape& boxes_shape = context->InputShape("boxes");
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
                errors::InvalidArgument("boxes must be 2-D, currently: ",
                                        boxes_shape.DebugString()));
    const int64 num_boxes = boxes_shape.dim_size(0);
    OP_REQUIRES(context, boxes_shape.dim_size(1) == 4,
                errors::InvalidArgument("boxes must have 4 columns",
                                        boxes_shape.DebugString()));
    const TensorShape& scores_shape = context->InputShape("scores");
    OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
                errors::InvalidArgument("scores must be 1-D, currently: ",
                                        scores_shape.DebugString()));
    OP_REQUIRES(
        context, scores_shape.dim_size(0) == num_boxes,
        errors::InvalidArgument("scores size must equal number of boxes",
                                scores_shape.DebugString()));
    OP_REQUIRES(context, pad_to_max_output_size_,
                errors::Unimplemented(
                    "XLA compilation requires pad_to_max_output_size == True"));
    // Indices in the loop are carried as int32, so the box count must fit.
    OP_REQUIRES(context, num_boxes <= kint32max,
                errors::InvalidArgument("XLA compilation requires number of "
                                        "boxes to be <= kint32max, got ",
                                        num_boxes));
    xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes");
    xla::PrimitiveType scores_xla_type = context->InputXlaType("scores");
    const xla::XlaOp boxes_input = context->Input("boxes");
    const xla::XlaOp scores_input = context->Input("scores");
    // max_output_size must be a compile-time constant (see registration).
    int64 output_size;
    OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
    OP_REQUIRES(
        context, output_size >= 0,
        errors::InvalidArgument("Need output_size >= 0, got ", output_size));
    OP_REQUIRES(context, output_size <= kint32max,
                errors::InvalidArgument("Need output_size <= kint32Max, got ",
                                        output_size));
    const xla::XlaOp score_thresh = context->Input("score_threshold");
    const xla::XlaOp iou_thresh = context->Input("iou_threshold");
    xla::XlaBuilder* const builder = context->builder();

    // Choose a more convenient layout.
    // boxes becomes [4, num_boxes]: one row per coordinate. Broadcasting the
    // scores to [4, num_boxes] gives every coordinate row the same sort key,
    // so sorting along dimension 1 orders all four rows by descending score.
    const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
    const xla::XlaOp boxes_sorted = xla::GetTupleElement(
        xla::Sort({xla::Broadcast(scores_input, {4}), boxes},
                  xla::CreateScalarGtComputation(
                      {scores_xla_type, boxes_xla_type}, builder),
                  /*dimension=*/1),
        1);
    // Track the mapping of indices into sorted domain.
    const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
    const xla::XlaOp indices_sort = xla::Sort(
        {scores_input, iota_indices},
        xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder));
    const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
    const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0);

    // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0.
    const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/0,
                                                         /*limit_index=*/1,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/1,
                                                         /*limit_index=*/2,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/2,
                                                         /*limit_index=*/3,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/3,
                                                         /*limit_index=*/4,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});

    // Canonicalize the corners so (y1, x1) <= (y2, x2) regardless of the
    // order in which the caller specified them.
    xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
    xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
    xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
    xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
    xla::XlaOp area = (y2 - y1) * (x2 - x1);

    // Shapes are henceforth [1, num_boxes].
    y1 = xla::Broadcast(y1, {1});
    y2 = xla::Broadcast(y2, {1});
    x1 = xla::Broadcast(x1, {1});
    x2 = xla::Broadcast(x2, {1});
    area = xla::Broadcast(area, {1});

    // Shapes are henceforth [num_boxes, num_boxes]: each [1, num_boxes] op
    // combined with its own transpose yields the full pairwise matrix.
    xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
    xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
    xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
    xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
    auto square_zero = xla::ZerosLike(i_xmin);

    // Pairwise intersection area; Max with zero handles disjoint boxes.
    xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                        xla::Max(i_ymax - i_ymin, square_zero);
    xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
    xla::XlaOp iou = i_area / u_area;

    // True where a pair of boxes overlaps more than the threshold.
    xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
    xla::XlaOp included_iou =
        xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});

    // Loop carry: see WhileCondFn / SuppressBodyFn for the element meanings.
    std::vector<xla::XlaOp> init_values;
    init_values.reserve(4);
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // col_idx
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
    init_values.push_back(iou_thresh_mask);
    init_values.push_back(included_iou);

    auto suppress_loop_result =
        xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
                             SuppressBodyFn(num_boxes), init_values,
                             "suppress_loop", builder)
            .ValueOrDie();

    // Apply the score threshold on top of the iou-based suppression.
    xla::XlaOp included_score =
        xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
    xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);

    // Only consider boxes over which we have iterated. This allows for accurate
    // counting. DynamicSlice would require knowledge of the size of the output.
    auto valid_elem = xla::Lt(
        iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
    included = xla::And(included, valid_elem);

    // Push excluded boxes to the bottom of the TopK ordering by giving them
    // the most negative score.
    // NOTE(review): MinValue is hard-coded to F32 here — assumes the scores
    // input is float32; confirm against the op's type registration.
    xla::XlaOp neg_inf =
        xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
    xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
    xla::XlaOp output_tuple = TopK(scores_included, output_size);
    xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
    // Calculate num_valid.
    // Note: num_valid cannot be taken from the loop outputs, because outputs
    // can be suppressed by score threshold.
    xla::XlaOp ones_included = xla::Select(
        included,
        xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
        xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
    // num_valid is scalar. Value should be bound by output_size.
    xla::XlaOp num_valid_total = xla::Reduce(
        ones_included,
        /*init_value=*/xla::ConstantR0<int>(builder, 0),
        /*computation=*/CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
    xla::XlaOp num_valid =
        xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));

    // Re-index into the original scores input tensor, using a Gather.
    // Boxes were suppressed in the sorted domain.
    xla::XlaOp selected_indices;
    DataType gather_type = context->expected_output_dtype(0);
    OP_REQUIRES_OK(
        context,
        XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
                  TensorShape({output_size}),
                  /*axis=*/0,
                  /*indices_are_nd=*/false,
                  /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));

    context->SetOutput(0, selected_indices);
    context->SetOutput(1, num_valid);
  }

 private:
  // Attribute: must be true under XLA so the output has a static shape.
  bool pad_to_max_output_size_;
};

REGISTER_XLA_OP(
    Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
    NonMaxSuppressionOp);
586
587 } // namespace
588 } // namespace tensorflow
589