1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19
20 namespace tensorflow {
21
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25
26 namespace {
27
28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29 // height and width come from the size_tensor.
SetOutputToSizedImage(InferenceContext * c,DimensionHandle batch_dim,int size_input_idx,DimensionHandle channel_dim)30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31 int size_input_idx, DimensionHandle channel_dim) {
32 // Verify shape of size input.
33 ShapeHandle size;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35 DimensionHandle unused;
36 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37
38 // Get size values from the size tensor.
39 const Tensor* size_tensor = c->input_tensor(size_input_idx);
40 DimensionHandle width;
41 DimensionHandle height;
42 if (size_tensor == nullptr) {
43 width = c->UnknownDim();
44 height = c->UnknownDim();
45 } else {
46 // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47 if (size_tensor->dtype() != DT_INT32) {
48 return errors::InvalidArgument(
49 "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50 "but got ",
51 DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52 " in ", c->DebugString());
53 }
54 auto vec = size_tensor->vec<int32>();
55 height = c->MakeDim(vec(0));
56 width = c->MakeDim(vec(1));
57 }
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59 return Status::OK();
60 }
61
ResizeShapeFn(InferenceContext * c)62 Status ResizeShapeFn(InferenceContext* c) {
63 ShapeHandle input;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65 return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66 c->Dim(input, 3));
67 }
68
DecodeImageShapeFn(InferenceContext * c)69 Status DecodeImageShapeFn(InferenceContext* c) {
70 ShapeHandle unused;
71 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72 DimensionHandle channels_dim;
73 int32 channels;
74 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75 if (channels == 0) {
76 channels_dim = c->UnknownDim();
77 } else {
78 if (channels < 0) {
79 return errors::InvalidArgument("channels must be non-negative, got ",
80 channels);
81 }
82 channels_dim = c->MakeDim(channels);
83 }
84
85 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86 InferenceContext::kUnknownDim, channels_dim}));
87 return Status::OK();
88 }
89
EncodeImageShapeFn(InferenceContext * c)90 Status EncodeImageShapeFn(InferenceContext* c) {
91 ShapeHandle unused;
92 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
93 c->set_output(0, c->Scalar());
94 return Status::OK();
95 }
96
ColorspaceShapeFn(InferenceContext * c)97 Status ColorspaceShapeFn(InferenceContext* c) {
98 ShapeHandle input;
99 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
100
101 // The last dimension value is always 3.
102 DimensionHandle last_dim;
103 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
104 ShapeHandle out;
105 TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
106 c->set_output(0, out);
107
108 return Status::OK();
109 }
110
111 } // namespace
112
113 // --------------------------------------------------------------------------
// Inputs are `images` (element type T) and `size` (int32). The output is
// always float, whatever T is. Shape inference (ResizeShapeFn) requires
// `images` to be 4-D and `size` to be a 2-element vector, and produces
// [images.dim(0), size[0], size[1], images.dim(3)]. Those two middle dims are
// unknown when the value of `size` is not statically available.
REGISTER_OP("ResizeArea")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
121
122 // --------------------------------------------------------------------------
// Inputs are `images` (element type T) and `size` (int32). The output is
// always float. Shape inference (ResizeShapeFn) requires 4-D `images` and a
// 2-element `size`, and produces
// [images.dim(0), size[0], size[1], images.dim(3)].
REGISTER_OP("ResizeBicubic")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
130
131 // --------------------------------------------------------------------------
// `grads` is always float, while `original_image` and the output use T.
// Shape inference: the output has exactly the shape of `original_image`
// (input 1). The shape of `grads` is not validated here.
REGISTER_OP("ResizeBicubicGrad")
    .Input("grads: float")
    .Input("original_image: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      // Pass the shape of `original_image` through unchanged.
      c->set_output(0, c->input(1));
      return Status::OK();
    });
142
143 // --------------------------------------------------------------------------
// Inputs are `images` (element type T, which here also allows bfloat16) and
// `size` (int32). The output is always float. Shape inference
// (ResizeShapeFn) requires 4-D `images` and a 2-element `size`, and produces
// [images.dim(0), size[0], size[1], images.dim(3)].
REGISTER_OP("ResizeBilinear")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
153
154 // --------------------------------------------------------------------------
155 REGISTER_OP("QuantizedResizeBilinear")
156 .Input("images: T")
157 .Input("size: int32")
158 .Input("min: float")
159 .Input("max: float")
160 .Output("resized_images: T")
161 .Output("out_min: float")
162 .Output("out_max: float")
163 .Attr("T: {quint8, qint32, float}")
164 .Attr("align_corners: bool = false")
__anonc970e8b60302(InferenceContext* c) 165 .SetShapeFn([](InferenceContext* c) {
166 TF_RETURN_IF_ERROR(ResizeShapeFn(c));
167 ShapeHandle min_shape;
168 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
169 ShapeHandle max_shape;
170 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
171 c->set_output(1, c->MakeShape({}));
172 c->set_output(2, c->MakeShape({}));
173 return Status::OK();
174 });
175
176 // --------------------------------------------------------------------------
// `grads` is always float, while `original_image` and the output use T.
// Shape inference: the output has exactly the shape of `original_image`
// (input 1). The shape of `grads` is not validated here.
REGISTER_OP("ResizeBilinearGrad")
    .Input("grads: float")
    .Input("original_image: T")
    .Output("output: T")
    .Attr("T: {float, bfloat16, half, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      // Pass the shape of `original_image` through unchanged.
      c->set_output(0, c->input(1));
      return Status::OK();
    });
187
188 // --------------------------------------------------------------------------
// Unlike the other Resize* ops registered in this file, the output keeps the
// input element type T rather than being float. Shape inference
// (ResizeShapeFn) requires 4-D `images` and a 2-element `size`, and produces
// [images.dim(0), size[0], size[1], images.dim(3)].
REGISTER_OP("ResizeNearestNeighbor")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: T")
    .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
196
197 // --------------------------------------------------------------------------
198 REGISTER_OP("ResizeNearestNeighborGrad")
199 .Input("grads: T")
200 .Input("size: int32")
201 .Output("output: T")
202 .Attr("T: {uint8, int8, int32, half, float, double}")
203 .Attr("align_corners: bool = false")
__anonc970e8b60502(InferenceContext* c) 204 .SetShapeFn([](InferenceContext* c) {
205 ShapeHandle input;
206 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
207 ShapeHandle unused;
208 DimensionHandle unused_dim;
209 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
210 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
211 const Tensor* size = c->input_tensor(1);
212 if (size == nullptr) {
213 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
214 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
215 } else {
216 auto size_vec = size->vec<int32>();
217 TF_RETURN_IF_ERROR(
218 c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
219 TF_RETURN_IF_ERROR(
220 c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
221 }
222 c->set_output(0, input);
223 return Status::OK();
224 });
225
226 // --------------------------------------------------------------------------
227 REGISTER_OP("RandomCrop")
228 .Input("image: T")
229 .Input("size: int64")
230 .Output("output: T")
231 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
232 .Attr("seed: int = 0")
233 .Attr("seed2: int = 0")
234 .SetIsStateful()
235 .Deprecated(8, "Random crop is now pure Python")
__anonc970e8b60602(InferenceContext* c) 236 .SetShapeFn([](InferenceContext* c) {
237 ShapeHandle image;
238 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
239 DimensionHandle channels = c->Dim(image, -1);
240
241 ShapeHandle unused;
242 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
243
244 const Tensor* size = c->input_tensor(1);
245 DimensionHandle h;
246 DimensionHandle w;
247 if (size == nullptr) {
248 h = c->UnknownDim();
249 w = c->UnknownDim();
250 } else {
251 auto size_vec = size->vec<int64>();
252 h = c->MakeDim(size_vec(0));
253 w = c->MakeDim(size_vec(1));
254 }
255 c->set_output(0, c->MakeShape({h, w, channels}));
256 return Status::OK();
257 });
258 // TODO(shlens): Support variable rank in RandomCrop.
259
260 // --------------------------------------------------------------------------
// Shape inference (DecodeImageShapeFn): `contents` must be a scalar string.
// The uint8 output is rank 3 with unknown leading dims. Its last dim is
// `channels` when that attr is positive, and unknown when it is 0. A negative
// `channels` is rejected.
REGISTER_OP("DecodeJpeg")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("ratio: int = 1")
    .Attr("fancy_upscaling: bool = true")
    .Attr("try_recover_truncated: bool = false")
    .Attr("acceptable_fraction: float = 1.0")
    .Attr("dct_method: string = ''")
    .Output("image: uint8")
    .SetShapeFn(DecodeImageShapeFn);
271
272 // --------------------------------------------------------------------------
273 REGISTER_OP("DecodeAndCropJpeg")
274 .Input("contents: string")
275 .Input("crop_window: int32")
276 .Attr("channels: int = 0")
277 .Attr("ratio: int = 1")
278 .Attr("fancy_upscaling: bool = true")
279 .Attr("try_recover_truncated: bool = false")
280 .Attr("acceptable_fraction: float = 1.0")
281 .Attr("dct_method: string = ''")
282 .Output("image: uint8")
__anonc970e8b60702(InferenceContext* c) 283 .SetShapeFn([](InferenceContext* c) {
284 ShapeHandle unused;
285 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
286 DimensionHandle channels_dim = c->UnknownDim();
287 DimensionHandle h = c->UnknownDim();
288 DimensionHandle w = c->UnknownDim();
289
290 int32 channels;
291 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
292 if (channels != 0) {
293 if (channels < 0) {
294 return errors::InvalidArgument("channels must be non-negative, got ",
295 channels);
296 }
297 channels_dim = c->MakeDim(channels);
298 }
299
300 DimensionHandle unused_dim;
301 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
302 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
303
304 const Tensor* crop_window = c->input_tensor(1);
305 if (crop_window != nullptr) {
306 auto crop_window_vec = crop_window->vec<int32>();
307 h = c->MakeDim(crop_window_vec(2));
308 w = c->MakeDim(crop_window_vec(3));
309 }
310 c->set_output(0, c->MakeShape({h, w, channels_dim}));
311 return Status::OK();
312 });
313
314 // --------------------------------------------------------------------------
// Shape inference (EncodeImageShapeFn): the uint8 `image` must be 3-D, and
// the string `contents` output is a scalar. The `format` and `density_unit`
// attrs are restricted to the listed string values.
REGISTER_OP("EncodeJpeg")
    .Input("image: uint8")
    .Attr("format: {'', 'grayscale', 'rgb'} = ''")
    .Attr("quality: int = 95")
    .Attr("progressive: bool = false")
    .Attr("optimize_size: bool = false")
    .Attr("chroma_downsampling: bool = true")
    .Attr("density_unit: {'in', 'cm'} = 'in'")
    .Attr("x_density: int = 300")
    .Attr("y_density: int = 300")
    .Attr("xmp_metadata: string = ''")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
328
329 // --------------------------------------------------------------------------
// The output element type is selected by the `output_type` attr (int32 by
// default).
REGISTER_OP("ExtractJpegShape")
    .Input("contents: string")
    .Output("image_shape: output_type")
    .Attr("output_type: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      // `contents` must be a scalar; the result is always a 3-element vector.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      c->set_output(0, c->Vector(3));
      return Status::OK();
    });
340
341 // --------------------------------------------------------------------------
// Deprecated from GraphDef version 2 in favor of AdjustContrastv2. The output
// is always float. Shape inference delegates to
// shape_inference::UnchangedShapeWithRankAtLeast(c, 3): `images` must have
// rank >= 3 and the output keeps its shape. The shapes of `contrast_factor`,
// `min_value` and `max_value` are not checked here.
REGISTER_OP("AdjustContrast")
    .Input("images: T")
    .Input("contrast_factor: float")
    .Input("min_value: float")
    .Input("max_value: float")
    .Output("output: float")
    .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
    .Deprecated(2, "Use AdjustContrastv2 instead")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
    });
353
354 // --------------------------------------------------------------------------
// Shape inference delegates to UnchangedShapeWithRankAtLeast(c, 3): `images`
// must have rank >= 3 and the output keeps its shape. The shape of
// `contrast_factor` is not checked here.
REGISTER_OP("AdjustContrastv2")
    .Input("images: float")
    .Input("contrast_factor: float")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
    });
362
363 // --------------------------------------------------------------------------
// Shape inference delegates to UnchangedShapeWithRankAtLeast(c, 3): `images`
// must have rank >= 3 and the output keeps its shape. The shape of `delta` is
// not checked here.
REGISTER_OP("AdjustHue")
    .Input("images: float")
    .Input("delta: float")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
    });
371
372 // --------------------------------------------------------------------------
// Shape inference delegates to UnchangedShapeWithRankAtLeast(c, 3): `images`
// must have rank >= 3 and the output keeps its shape. The shape of `scale` is
// not checked here.
REGISTER_OP("AdjustSaturation")
    .Input("images: float")
    .Input("scale: float")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
    });
380
381 // --------------------------------------------------------------------------
// The output element type follows the `dtype` attr (uint8 or uint16). Shape
// inference (DecodeImageShapeFn): scalar `contents`; the output is rank 3 with
// the last dim taken from `channels` (unknown when it is 0).
REGISTER_OP("DecodePng")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16} = DT_UINT8")
    .Output("image: dtype")
    .SetShapeFn(DecodeImageShapeFn);
388
389 // --------------------------------------------------------------------------
// `image` may be uint8 or uint16 (attr T). Shape inference
// (EncodeImageShapeFn): `image` must be 3-D, and `contents` is a scalar.
REGISTER_OP("EncodePng")
    .Attr("compression: int = -1")
    .Attr("T: {uint8, uint16} = DT_UINT8")
    .Input("image: T")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
396
397 // --------------------------------------------------------------------------
// Shape inference (DecodeImageShapeFn): scalar `contents`; the uint8 output is
// rank 3 with the last dim taken from `channels` (unknown when it is 0).
REGISTER_OP("DecodeBmp")
    .Input("contents: string")
    .Output("image: uint8")
    .Attr("channels: int = 0")
    .SetShapeFn(DecodeImageShapeFn);
403
404 // --------------------------------------------------------------------------
405 REGISTER_OP("DecodeGif")
406 .Input("contents: string")
407 .Output("image: uint8")
__anonc970e8b60d02(InferenceContext* c) 408 .SetShapeFn([](InferenceContext* c) {
409 ShapeHandle unused;
410 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
411 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
412 InferenceContext::kUnknownDim,
413 InferenceContext::kUnknownDim, 3}));
414 return Status::OK();
415 });
416
417 // --------------------------------------------------------------------------
// Shape inference (ColorspaceShapeFn): `images` must have rank >= 1 with a
// last dim of 3, and the output has the same shape.
REGISTER_OP("RGBToHSV")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
423
424 // --------------------------------------------------------------------------
// Shape inference (ColorspaceShapeFn): `images` must have rank >= 1 with a
// last dim of 3, and the output has the same shape.
REGISTER_OP("HSVToRGB")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
430
431 // --------------------------------------------------------------------------
432 REGISTER_OP("DrawBoundingBoxes")
433 .Input("images: T")
434 .Input("boxes: float")
435 .Output("output: T")
436 .Attr("T: {float, half} = DT_FLOAT")
__anonc970e8b60e02(InferenceContext* c) 437 .SetShapeFn([](InferenceContext* c) {
438 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
439 });
440
441 // --------------------------------------------------------------------------
442 REGISTER_OP("SampleDistortedBoundingBox")
443 .Input("image_size: T")
444 .Input("bounding_boxes: float")
445 .Output("begin: T")
446 .Output("size: T")
447 .Output("bboxes: float")
448 .Attr("T: {uint8, int8, int16, int32, int64}")
449 .Attr("seed: int = 0")
450 .Attr("seed2: int = 0")
451 .Attr("min_object_covered: float = 0.1")
452 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
453 .Attr("area_range: list(float) = [0.05, 1.0]")
454 .Attr("max_attempts: int = 100")
455 .Attr("use_image_if_no_bounding_boxes: bool = false")
456 .SetIsStateful()
__anonc970e8b60f02(InferenceContext* c) 457 .SetShapeFn([](InferenceContext* c) {
458 // Get inputs and validate ranks.
459 ShapeHandle image_size;
460 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
461 ShapeHandle bounding_boxes;
462 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
463 // image_size: 1-D with [height, width, channels]
464 // bounding_boxes: 3-D with shape [batch, N, 4]
465 DimensionHandle unused;
466 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
467 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
468
469 c->set_output(0, c->Vector(3));
470 c->set_output(1, c->Vector(3));
471 c->set_output(2, c->MakeShape({1, 1, 4}));
472 return Status::OK();
473 });
474
475 REGISTER_OP("SampleDistortedBoundingBoxV2")
476 .Input("image_size: T")
477 .Input("bounding_boxes: float")
478 .Input("min_object_covered: float")
479 .Output("begin: T")
480 .Output("size: T")
481 .Output("bboxes: float")
482 .Attr("T: {uint8, int8, int16, int32, int64}")
483 .Attr("seed: int = 0")
484 .Attr("seed2: int = 0")
485 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
486 .Attr("area_range: list(float) = [0.05, 1.0]")
487 .Attr("max_attempts: int = 100")
488 .Attr("use_image_if_no_bounding_boxes: bool = false")
489 .SetIsStateful()
__anonc970e8b61002(InferenceContext* c) 490 .SetShapeFn([](InferenceContext* c) {
491 // Get inputs and validate ranks.
492 ShapeHandle image_size;
493 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
494 ShapeHandle bounding_boxes;
495 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
496 ShapeHandle min_object_covered;
497 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
498 // image_size: 1-D with [height, width, channels]
499 // bounding_boxes: 3-D with shape [batch, N, 4]
500 DimensionHandle unused;
501 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
502 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
503
504 c->set_output(0, c->Vector(3));
505 c->set_output(1, c->Vector(3));
506 c->set_output(2, c->MakeShape({1, 1, 4}));
507 return Status::OK();
508 });
509
510 // --------------------------------------------------------------------------
511
512 // glimpse = extract_glimpse(input, size, offsets) extract the glimpse
513 // of size `size` centered at location `offsets` from the input tensor
514 // `input`.
515 //
516 // REQUIRES: input.dims() == 4
517 //
518 REGISTER_OP("ExtractGlimpse")
519 .Input("input: float")
520 .Input("size: int32")
521 .Input("offsets: float")
522 .Output("glimpse: float")
523 .Attr("centered: bool = true")
524 .Attr("normalized: bool = true")
525 .Attr("uniform_noise: bool = true")
__anonc970e8b61102(InferenceContext* c) 526 .SetShapeFn([](InferenceContext* c) {
527 ShapeHandle input;
528 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
529 ShapeHandle offsets;
530 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
531
532 DimensionHandle batch_dim;
533 TF_RETURN_IF_ERROR(
534 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
535 DimensionHandle unused;
536 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
537
538 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
539 c->Dim(input, 3));
540 });
541
542 // --------------------------------------------------------------------------
543
544 REGISTER_OP("CropAndResize")
545 .Input("image: T")
546 .Input("boxes: float")
547 .Input("box_ind: int32")
548 .Input("crop_size: int32")
549 .Output("crops: float")
550 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
551 .Attr("method: {'bilinear'} = 'bilinear'")
552 .Attr("extrapolation_value: float = 0")
__anonc970e8b61202(InferenceContext* c) 553 .SetShapeFn([](InferenceContext* c) {
554 // Get inputs and validate ranks.
555 ShapeHandle input;
556 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
557 ShapeHandle boxes;
558 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
559 ShapeHandle box_ind;
560 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
561
562 // boxes[0] and box_ind[0] are both num_boxes.
563 DimensionHandle num_boxes_dim;
564 TF_RETURN_IF_ERROR(
565 c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
566
567 // boxes.dim(1) is 4.
568 DimensionHandle unused;
569 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
570
571 return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
572 c->Dim(input, 3));
573 });
574
575 REGISTER_OP("CropAndResizeGradImage")
576 .Input("grads: float")
577 .Input("boxes: float")
578 .Input("box_ind: int32")
579 .Input("image_size: int32")
580 .Output("output: T")
581 .Attr("T: {float, half, double}")
582 .Attr("method: {'bilinear'} = 'bilinear'")
__anonc970e8b61302(InferenceContext* c) 583 .SetShapeFn([](InferenceContext* c) {
584 ShapeHandle out;
585 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
586 TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
587 c->set_output(0, out);
588 return Status::OK();
589 });
590
// Shape inference: the float output has exactly the shape of `boxes`
// (input 2). The other inputs are not validated here.
REGISTER_OP("CropAndResizeGradBoxes")
    .Input("grads: float")
    .Input("image: T")
    .Input("boxes: float")
    .Input("box_ind: int32")
    .Output("output: float")
    .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
    .Attr("method: {'bilinear'} = 'bilinear'")
    .SetShapeFn([](InferenceContext* c) {
      // Pass the shape of `boxes` through unchanged.
      c->set_output(0, c->input(2));
      return Status::OK();
    });
603
604 // --------------------------------------------------------------------------
605
606 REGISTER_OP("NonMaxSuppression")
607 .Input("boxes: float")
608 .Input("scores: float")
609 .Input("max_output_size: int32")
610 .Output("selected_indices: int32")
611 .Attr("iou_threshold: float = 0.5")
__anonc970e8b61502(InferenceContext* c) 612 .SetShapeFn([](InferenceContext* c) {
613 // Get inputs and validate ranks.
614 ShapeHandle boxes;
615 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
616 ShapeHandle scores;
617 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
618 ShapeHandle max_output_size;
619 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
620 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
621 DimensionHandle unused;
622 // The boxes[0] and scores[0] are both num_boxes.
623 TF_RETURN_IF_ERROR(
624 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
625 // The boxes[1] is 4.
626 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
627
628 c->set_output(0, c->Vector(c->UnknownDim()));
629 return Status::OK();
630 });
631
632 REGISTER_OP("NonMaxSuppressionV2")
633 .Input("boxes: float")
634 .Input("scores: float")
635 .Input("max_output_size: int32")
636 .Input("iou_threshold: float")
637 .Output("selected_indices: int32")
__anonc970e8b61602(InferenceContext* c) 638 .SetShapeFn([](InferenceContext* c) {
639 // Get inputs and validate ranks.
640 ShapeHandle boxes;
641 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
642 ShapeHandle scores;
643 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
644 ShapeHandle max_output_size;
645 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
646 ShapeHandle iou_threshold;
647 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
648 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
649 DimensionHandle unused;
650 // The boxes[0] and scores[0] are both num_boxes.
651 TF_RETURN_IF_ERROR(
652 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
653 // The boxes[1] is 4.
654 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
655
656 c->set_output(0, c->Vector(c->UnknownDim()));
657 return Status::OK();
658 });
659
660 } // namespace tensorflow
661