1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29 // height and width come from the size_tensor.
SetOutputToSizedImage(InferenceContext * c,DimensionHandle batch_dim,int size_input_idx,DimensionHandle channel_dim)30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31                              int size_input_idx, DimensionHandle channel_dim) {
32   // Verify shape of size input.
33   ShapeHandle size;
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35   DimensionHandle unused;
36   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37 
38   // Get size values from the size tensor.
39   const Tensor* size_tensor = c->input_tensor(size_input_idx);
40   DimensionHandle width;
41   DimensionHandle height;
42   if (size_tensor == nullptr) {
43     width = c->UnknownDim();
44     height = c->UnknownDim();
45   } else {
46     // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47     if (size_tensor->dtype() != DT_INT32) {
48       return errors::InvalidArgument(
49           "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50           "but got ",
51           DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52           " in ", c->DebugString());
53     }
54     auto vec = size_tensor->vec<int32>();
55     height = c->MakeDim(vec(0));
56     width = c->MakeDim(vec(1));
57   }
58   c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59   return Status::OK();
60 }
61 
ResizeShapeFn(InferenceContext * c)62 Status ResizeShapeFn(InferenceContext* c) {
63   ShapeHandle input;
64   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65   return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66                                c->Dim(input, 3));
67 }
68 
DecodeImageShapeFn(InferenceContext * c)69 Status DecodeImageShapeFn(InferenceContext* c) {
70   ShapeHandle unused;
71   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72   DimensionHandle channels_dim;
73   int32 channels;
74   TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75   if (channels == 0) {
76     channels_dim = c->UnknownDim();
77   } else {
78     if (channels < 0) {
79       return errors::InvalidArgument("channels must be non-negative, got ",
80                                      channels);
81     }
82     channels_dim = c->MakeDim(channels);
83   }
84 
85   c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86                                  InferenceContext::kUnknownDim, channels_dim}));
87   return Status::OK();
88 }
89 
EncodeImageShapeFn(InferenceContext * c)90 Status EncodeImageShapeFn(InferenceContext* c) {
91   ShapeHandle unused;
92   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
93   c->set_output(0, c->Scalar());
94   return Status::OK();
95 }
96 
ColorspaceShapeFn(InferenceContext * c)97 Status ColorspaceShapeFn(InferenceContext* c) {
98   ShapeHandle input;
99   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
100 
101   // The last dimension value is always 3.
102   DimensionHandle last_dim;
103   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
104   ShapeHandle out;
105   TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
106   c->set_output(0, out);
107 
108   return Status::OK();
109 }
110 
111 }  // namespace
112 
113 // --------------------------------------------------------------------------
114 REGISTER_OP("ResizeArea")
115     .Input("images: T")
116     .Input("size: int32")
117     .Output("resized_images: float")
118     .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
119     .Attr("align_corners: bool = false")
120     .SetShapeFn(ResizeShapeFn);
121 
122 // --------------------------------------------------------------------------
123 REGISTER_OP("ResizeBicubic")
124     .Input("images: T")
125     .Input("size: int32")
126     .Output("resized_images: float")
127     .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
128     .Attr("align_corners: bool = false")
129     .SetShapeFn(ResizeShapeFn);
130 
131 // --------------------------------------------------------------------------
132 REGISTER_OP("ResizeBicubicGrad")
133     .Input("grads: float")
134     .Input("original_image: T")
135     .Output("output: T")
136     .Attr("T: {float, double}")
137     .Attr("align_corners: bool = false")
__anonc970e8b60202(InferenceContext* c) 138     .SetShapeFn([](InferenceContext* c) {
139       c->set_output(0, c->input(1));
140       return Status::OK();
141     });
142 
143 // --------------------------------------------------------------------------
144 REGISTER_OP("ResizeBilinear")
145     .Input("images: T")
146     .Input("size: int32")
147     .Output("resized_images: float")
148     .Attr(
149         "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
150         "float, double}")
151     .Attr("align_corners: bool = false")
152     .SetShapeFn(ResizeShapeFn);
153 
154 // --------------------------------------------------------------------------
155 REGISTER_OP("QuantizedResizeBilinear")
156     .Input("images: T")
157     .Input("size: int32")
158     .Input("min: float")
159     .Input("max: float")
160     .Output("resized_images: T")
161     .Output("out_min: float")
162     .Output("out_max: float")
163     .Attr("T: {quint8, qint32, float}")
164     .Attr("align_corners: bool = false")
__anonc970e8b60302(InferenceContext* c) 165     .SetShapeFn([](InferenceContext* c) {
166       TF_RETURN_IF_ERROR(ResizeShapeFn(c));
167       ShapeHandle min_shape;
168       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
169       ShapeHandle max_shape;
170       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
171       c->set_output(1, c->MakeShape({}));
172       c->set_output(2, c->MakeShape({}));
173       return Status::OK();
174     });
175 
176 // --------------------------------------------------------------------------
177 REGISTER_OP("ResizeBilinearGrad")
178     .Input("grads: float")
179     .Input("original_image: T")
180     .Output("output: T")
181     .Attr("T: {float, bfloat16, half, double}")
182     .Attr("align_corners: bool = false")
__anonc970e8b60402(InferenceContext* c) 183     .SetShapeFn([](InferenceContext* c) {
184       c->set_output(0, c->input(1));
185       return Status::OK();
186     });
187 
188 // --------------------------------------------------------------------------
189 REGISTER_OP("ResizeNearestNeighbor")
190     .Input("images: T")
191     .Input("size: int32")
192     .Output("resized_images: T")
193     .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
194     .Attr("align_corners: bool = false")
195     .SetShapeFn(ResizeShapeFn);
196 
197 // --------------------------------------------------------------------------
198 REGISTER_OP("ResizeNearestNeighborGrad")
199     .Input("grads: T")
200     .Input("size: int32")
201     .Output("output: T")
202     .Attr("T: {uint8, int8, int32, half, float, double}")
203     .Attr("align_corners: bool = false")
__anonc970e8b60502(InferenceContext* c) 204     .SetShapeFn([](InferenceContext* c) {
205       ShapeHandle input;
206       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
207       ShapeHandle unused;
208       DimensionHandle unused_dim;
209       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
210       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
211       const Tensor* size = c->input_tensor(1);
212       if (size == nullptr) {
213         TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
214         TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
215       } else {
216         auto size_vec = size->vec<int32>();
217         TF_RETURN_IF_ERROR(
218             c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
219         TF_RETURN_IF_ERROR(
220             c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
221       }
222       c->set_output(0, input);
223       return Status::OK();
224     });
225 
226 // --------------------------------------------------------------------------
227 REGISTER_OP("RandomCrop")
228     .Input("image: T")
229     .Input("size: int64")
230     .Output("output: T")
231     .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
232     .Attr("seed: int = 0")
233     .Attr("seed2: int = 0")
234     .SetIsStateful()
235     .Deprecated(8, "Random crop is now pure Python")
__anonc970e8b60602(InferenceContext* c) 236     .SetShapeFn([](InferenceContext* c) {
237       ShapeHandle image;
238       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
239       DimensionHandle channels = c->Dim(image, -1);
240 
241       ShapeHandle unused;
242       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
243 
244       const Tensor* size = c->input_tensor(1);
245       DimensionHandle h;
246       DimensionHandle w;
247       if (size == nullptr) {
248         h = c->UnknownDim();
249         w = c->UnknownDim();
250       } else {
251         auto size_vec = size->vec<int64>();
252         h = c->MakeDim(size_vec(0));
253         w = c->MakeDim(size_vec(1));
254       }
255       c->set_output(0, c->MakeShape({h, w, channels}));
256       return Status::OK();
257     });
258 // TODO(shlens): Support variable rank in RandomCrop.
259 
260 // --------------------------------------------------------------------------
261 REGISTER_OP("DecodeJpeg")
262     .Input("contents: string")
263     .Attr("channels: int = 0")
264     .Attr("ratio: int = 1")
265     .Attr("fancy_upscaling: bool = true")
266     .Attr("try_recover_truncated: bool = false")
267     .Attr("acceptable_fraction: float = 1.0")
268     .Attr("dct_method: string = ''")
269     .Output("image: uint8")
270     .SetShapeFn(DecodeImageShapeFn);
271 
272 // --------------------------------------------------------------------------
273 REGISTER_OP("DecodeAndCropJpeg")
274     .Input("contents: string")
275     .Input("crop_window: int32")
276     .Attr("channels: int = 0")
277     .Attr("ratio: int = 1")
278     .Attr("fancy_upscaling: bool = true")
279     .Attr("try_recover_truncated: bool = false")
280     .Attr("acceptable_fraction: float = 1.0")
281     .Attr("dct_method: string = ''")
282     .Output("image: uint8")
__anonc970e8b60702(InferenceContext* c) 283     .SetShapeFn([](InferenceContext* c) {
284       ShapeHandle unused;
285       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
286       DimensionHandle channels_dim = c->UnknownDim();
287       DimensionHandle h = c->UnknownDim();
288       DimensionHandle w = c->UnknownDim();
289 
290       int32 channels;
291       TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
292       if (channels != 0) {
293         if (channels < 0) {
294           return errors::InvalidArgument("channels must be non-negative, got ",
295                                          channels);
296         }
297         channels_dim = c->MakeDim(channels);
298       }
299 
300       DimensionHandle unused_dim;
301       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
302       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
303 
304       const Tensor* crop_window = c->input_tensor(1);
305       if (crop_window != nullptr) {
306         auto crop_window_vec = crop_window->vec<int32>();
307         h = c->MakeDim(crop_window_vec(2));
308         w = c->MakeDim(crop_window_vec(3));
309       }
310       c->set_output(0, c->MakeShape({h, w, channels_dim}));
311       return Status::OK();
312     });
313 
314 // --------------------------------------------------------------------------
315 REGISTER_OP("EncodeJpeg")
316     .Input("image: uint8")
317     .Attr("format: {'', 'grayscale', 'rgb'} = ''")
318     .Attr("quality: int = 95")
319     .Attr("progressive: bool = false")
320     .Attr("optimize_size: bool = false")
321     .Attr("chroma_downsampling: bool = true")
322     .Attr("density_unit: {'in', 'cm'} = 'in'")
323     .Attr("x_density: int = 300")
324     .Attr("y_density: int = 300")
325     .Attr("xmp_metadata: string = ''")
326     .Output("contents: string")
327     .SetShapeFn(EncodeImageShapeFn);
328 
329 // --------------------------------------------------------------------------
330 REGISTER_OP("ExtractJpegShape")
331     .Input("contents: string")
332     .Output("image_shape: output_type")
333     .Attr("output_type: {int32, int64} = DT_INT32")
__anonc970e8b60802(InferenceContext* c) 334     .SetShapeFn([](InferenceContext* c) {
335       ShapeHandle unused;
336       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
337       c->set_output(0, c->Vector(3));
338       return Status::OK();
339     });
340 
341 // --------------------------------------------------------------------------
342 REGISTER_OP("AdjustContrast")
343     .Input("images: T")
344     .Input("contrast_factor: float")
345     .Input("min_value: float")
346     .Input("max_value: float")
347     .Output("output: float")
348     .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
349     .Deprecated(2, "Use AdjustContrastv2 instead")
__anonc970e8b60902(InferenceContext* c) 350     .SetShapeFn([](InferenceContext* c) {
351       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
352     });
353 
354 // --------------------------------------------------------------------------
355 REGISTER_OP("AdjustContrastv2")
356     .Input("images: float")
357     .Input("contrast_factor: float")
358     .Output("output: float")
__anonc970e8b60a02(InferenceContext* c) 359     .SetShapeFn([](InferenceContext* c) {
360       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
361     });
362 
363 // --------------------------------------------------------------------------
364 REGISTER_OP("AdjustHue")
365     .Input("images: float")
366     .Input("delta: float")
367     .Output("output: float")
__anonc970e8b60b02(InferenceContext* c) 368     .SetShapeFn([](InferenceContext* c) {
369       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
370     });
371 
372 // --------------------------------------------------------------------------
373 REGISTER_OP("AdjustSaturation")
374     .Input("images: float")
375     .Input("scale: float")
376     .Output("output: float")
__anonc970e8b60c02(InferenceContext* c) 377     .SetShapeFn([](InferenceContext* c) {
378       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
379     });
380 
381 // --------------------------------------------------------------------------
382 REGISTER_OP("DecodePng")
383     .Input("contents: string")
384     .Attr("channels: int = 0")
385     .Attr("dtype: {uint8, uint16} = DT_UINT8")
386     .Output("image: dtype")
387     .SetShapeFn(DecodeImageShapeFn);
388 
389 // --------------------------------------------------------------------------
390 REGISTER_OP("EncodePng")
391     .Attr("compression: int = -1")
392     .Attr("T: {uint8, uint16} = DT_UINT8")
393     .Input("image: T")
394     .Output("contents: string")
395     .SetShapeFn(EncodeImageShapeFn);
396 
397 // --------------------------------------------------------------------------
398 REGISTER_OP("DecodeBmp")
399     .Input("contents: string")
400     .Output("image: uint8")
401     .Attr("channels: int = 0")
402     .SetShapeFn(DecodeImageShapeFn);
403 
404 // --------------------------------------------------------------------------
405 REGISTER_OP("DecodeGif")
406     .Input("contents: string")
407     .Output("image: uint8")
__anonc970e8b60d02(InferenceContext* c) 408     .SetShapeFn([](InferenceContext* c) {
409       ShapeHandle unused;
410       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
411       c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
412                                      InferenceContext::kUnknownDim,
413                                      InferenceContext::kUnknownDim, 3}));
414       return Status::OK();
415     });
416 
417 // --------------------------------------------------------------------------
418 REGISTER_OP("RGBToHSV")
419     .Input("images: T")
420     .Output("output: T")
421     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
422     .SetShapeFn(ColorspaceShapeFn);
423 
424 // --------------------------------------------------------------------------
425 REGISTER_OP("HSVToRGB")
426     .Input("images: T")
427     .Output("output: T")
428     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
429     .SetShapeFn(ColorspaceShapeFn);
430 
431 // --------------------------------------------------------------------------
432 REGISTER_OP("DrawBoundingBoxes")
433     .Input("images: T")
434     .Input("boxes: float")
435     .Output("output: T")
436     .Attr("T: {float, half} = DT_FLOAT")
__anonc970e8b60e02(InferenceContext* c) 437     .SetShapeFn([](InferenceContext* c) {
438       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
439     });
440 
441 // --------------------------------------------------------------------------
442 REGISTER_OP("SampleDistortedBoundingBox")
443     .Input("image_size: T")
444     .Input("bounding_boxes: float")
445     .Output("begin: T")
446     .Output("size: T")
447     .Output("bboxes: float")
448     .Attr("T: {uint8, int8, int16, int32, int64}")
449     .Attr("seed: int = 0")
450     .Attr("seed2: int = 0")
451     .Attr("min_object_covered: float = 0.1")
452     .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
453     .Attr("area_range: list(float) = [0.05, 1.0]")
454     .Attr("max_attempts: int = 100")
455     .Attr("use_image_if_no_bounding_boxes: bool = false")
456     .SetIsStateful()
__anonc970e8b60f02(InferenceContext* c) 457     .SetShapeFn([](InferenceContext* c) {
458       // Get inputs and validate ranks.
459       ShapeHandle image_size;
460       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
461       ShapeHandle bounding_boxes;
462       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
463       // image_size: 1-D with [height, width, channels]
464       // bounding_boxes: 3-D with shape [batch, N, 4]
465       DimensionHandle unused;
466       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
467       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
468 
469       c->set_output(0, c->Vector(3));
470       c->set_output(1, c->Vector(3));
471       c->set_output(2, c->MakeShape({1, 1, 4}));
472       return Status::OK();
473     });
474 
475 REGISTER_OP("SampleDistortedBoundingBoxV2")
476     .Input("image_size: T")
477     .Input("bounding_boxes: float")
478     .Input("min_object_covered: float")
479     .Output("begin: T")
480     .Output("size: T")
481     .Output("bboxes: float")
482     .Attr("T: {uint8, int8, int16, int32, int64}")
483     .Attr("seed: int = 0")
484     .Attr("seed2: int = 0")
485     .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
486     .Attr("area_range: list(float) = [0.05, 1.0]")
487     .Attr("max_attempts: int = 100")
488     .Attr("use_image_if_no_bounding_boxes: bool = false")
489     .SetIsStateful()
__anonc970e8b61002(InferenceContext* c) 490     .SetShapeFn([](InferenceContext* c) {
491       // Get inputs and validate ranks.
492       ShapeHandle image_size;
493       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
494       ShapeHandle bounding_boxes;
495       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
496       ShapeHandle min_object_covered;
497       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
498       // image_size: 1-D with [height, width, channels]
499       // bounding_boxes: 3-D with shape [batch, N, 4]
500       DimensionHandle unused;
501       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
502       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
503 
504       c->set_output(0, c->Vector(3));
505       c->set_output(1, c->Vector(3));
506       c->set_output(2, c->MakeShape({1, 1, 4}));
507       return Status::OK();
508     });
509 
510 // --------------------------------------------------------------------------
511 
512 // glimpse = extract_glimpse(input, size, offsets) extract the glimpse
513 // of size `size` centered at location `offsets` from the input tensor
514 // `input`.
515 //
516 // REQUIRES: input.dims() == 4
517 //
518 REGISTER_OP("ExtractGlimpse")
519     .Input("input: float")
520     .Input("size: int32")
521     .Input("offsets: float")
522     .Output("glimpse: float")
523     .Attr("centered: bool = true")
524     .Attr("normalized: bool = true")
525     .Attr("uniform_noise: bool = true")
__anonc970e8b61102(InferenceContext* c) 526     .SetShapeFn([](InferenceContext* c) {
527       ShapeHandle input;
528       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
529       ShapeHandle offsets;
530       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
531 
532       DimensionHandle batch_dim;
533       TF_RETURN_IF_ERROR(
534           c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
535       DimensionHandle unused;
536       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
537 
538       return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
539                                    c->Dim(input, 3));
540     });
541 
542 // --------------------------------------------------------------------------
543 
544 REGISTER_OP("CropAndResize")
545     .Input("image: T")
546     .Input("boxes: float")
547     .Input("box_ind: int32")
548     .Input("crop_size: int32")
549     .Output("crops: float")
550     .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
551     .Attr("method: {'bilinear'} = 'bilinear'")
552     .Attr("extrapolation_value: float = 0")
__anonc970e8b61202(InferenceContext* c) 553     .SetShapeFn([](InferenceContext* c) {
554       // Get inputs and validate ranks.
555       ShapeHandle input;
556       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
557       ShapeHandle boxes;
558       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
559       ShapeHandle box_ind;
560       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
561 
562       // boxes[0] and box_ind[0] are both num_boxes.
563       DimensionHandle num_boxes_dim;
564       TF_RETURN_IF_ERROR(
565           c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
566 
567       // boxes.dim(1) is 4.
568       DimensionHandle unused;
569       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
570 
571       return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
572                                    c->Dim(input, 3));
573     });
574 
575 REGISTER_OP("CropAndResizeGradImage")
576     .Input("grads: float")
577     .Input("boxes: float")
578     .Input("box_ind: int32")
579     .Input("image_size: int32")
580     .Output("output: T")
581     .Attr("T: {float, half, double}")
582     .Attr("method: {'bilinear'} = 'bilinear'")
__anonc970e8b61302(InferenceContext* c) 583     .SetShapeFn([](InferenceContext* c) {
584       ShapeHandle out;
585       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
586       TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
587       c->set_output(0, out);
588       return Status::OK();
589     });
590 
591 REGISTER_OP("CropAndResizeGradBoxes")
592     .Input("grads: float")
593     .Input("image: T")
594     .Input("boxes: float")
595     .Input("box_ind: int32")
596     .Output("output: float")
597     .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
598     .Attr("method: {'bilinear'} = 'bilinear'")
__anonc970e8b61402(InferenceContext* c) 599     .SetShapeFn([](InferenceContext* c) {
600       c->set_output(0, c->input(2));
601       return Status::OK();
602     });
603 
604 // --------------------------------------------------------------------------
605 
606 REGISTER_OP("NonMaxSuppression")
607     .Input("boxes: float")
608     .Input("scores: float")
609     .Input("max_output_size: int32")
610     .Output("selected_indices: int32")
611     .Attr("iou_threshold: float = 0.5")
__anonc970e8b61502(InferenceContext* c) 612     .SetShapeFn([](InferenceContext* c) {
613       // Get inputs and validate ranks.
614       ShapeHandle boxes;
615       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
616       ShapeHandle scores;
617       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
618       ShapeHandle max_output_size;
619       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
620       // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
621       DimensionHandle unused;
622       // The boxes[0] and scores[0] are both num_boxes.
623       TF_RETURN_IF_ERROR(
624           c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
625       // The boxes[1] is 4.
626       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
627 
628       c->set_output(0, c->Vector(c->UnknownDim()));
629       return Status::OK();
630     });
631 
632 REGISTER_OP("NonMaxSuppressionV2")
633     .Input("boxes: float")
634     .Input("scores: float")
635     .Input("max_output_size: int32")
636     .Input("iou_threshold: float")
637     .Output("selected_indices: int32")
__anonc970e8b61602(InferenceContext* c) 638     .SetShapeFn([](InferenceContext* c) {
639       // Get inputs and validate ranks.
640       ShapeHandle boxes;
641       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
642       ShapeHandle scores;
643       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
644       ShapeHandle max_output_size;
645       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
646       ShapeHandle iou_threshold;
647       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
648       // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
649       DimensionHandle unused;
650       // The boxes[0] and scores[0] are both num_boxes.
651       TF_RETURN_IF_ERROR(
652           c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
653       // The boxes[1] is 4.
654       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
655 
656       c->set_output(0, c->Vector(c->UnknownDim()));
657       return Status::OK();
658     });
659 
660 }  // namespace tensorflow
661