/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
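//
// This file implements the CPU kernels for CropAndResize,
// CropAndResizeGradImage, and CropAndResizeGradBoxes; the corresponding GPU
// kernels are registered at the bottom of the file under GOOGLE_CUDA.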

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/crop_and_resize_op.h"

#include <functional>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;

static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
                                           const Tensor& box_index,
                                           int* num_boxes) {
  if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
    *num_boxes = 0;
    return Status::OK();
  }
  // The shape of 'boxes' is [num_boxes, 4].
  if (boxes.dims() != 2) {
    return errors::InvalidArgument("boxes must be 2-D",
                                   boxes.shape().DebugString());
  }
  *num_boxes = boxes.dim_size(0);
  if (boxes.dim_size(1) != 4) {
    return errors::InvalidArgument("boxes must have 4 columns");
  }
  // The shape of 'box_index' is [num_boxes].
  if (box_index.dims() != 1) {
    return errors::InvalidArgument("box_index must be 1-D",
                                   box_index.shape().DebugString());
  }
  if (box_index.dim_size(0) != *num_boxes) {
    return errors::InvalidArgument("box_index has incompatible shape");
  }
  return Status::OK();
}

// Runs the compute callback if all values in box_index are in
// [0, batch_size), then calls done; if any index is out of range,
// done is called with an error status and compute is skipped.
template <typename Device>
inline void RunIfBoxIndexIsValid(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done);

// Specialization of RunIfBoxIndexIsValid for a CPUDevice.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  for (int b = 0; b < num_boxes; ++b) {
    OP_REQUIRES_ASYNC(
        context, FastBoundsCheck(box_index(b), batch_size),
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
  }
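  // On failure OP_REQUIRES_ASYNC calls done and returns from this function,
  // so the compute callback below runs only when every index passed the
  // bounds check.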
  if (compute) {
    compute();
  }
  if (done) {
    done();
  }
}

}  // namespace

template <typename Device, typename T>
class CropAndResizeOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
    OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
                errors::InvalidArgument(
                    "method must be 'bilinear' or 'nearest'", method_));
    OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
                                             &extrapolation_value_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'image' is [batch_size, image_height, image_width,
    // channels].
    const Tensor& image = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'crop_size' is [2].
    const Tensor& crop_size = context->input(3);

    // Validate input dimensions.
    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    const int depth = image.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
                      errors::InvalidArgument("crop_size must be 1-D",
                                              crop_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(
        context, crop_size.dim_size(0) == 2,
        errors::InvalidArgument("crop_size must have two elements",
                                crop_size.shape().DebugString()),
        done);

    // Copy and validate crop sizes.
    auto crop_size_vec = crop_size.vec<int32>();
    const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
    const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("crop dimensions must be positive"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({num_boxes, crop_height, crop_width, depth}),
            &output),
        done);

    auto compute_callback = [this, context, output]() {
      const Tensor& image = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResize<Device, T>()(
          context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), method_, extrapolation_value_,
          output->tensor<float, 4>());
      if (!status) {
        context->SetStatus(
            errors::Internal("Failed to launch CropAndResizeKernel."));
      }
    };

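    // On the CPU the box-index check below runs synchronously before compute;
    // the GPU specialization of RunIfBoxIndexIsValid instead validates
    // on-device and defers compute/done to a stream callback.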
    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }

 private:
  float extrapolation_value_;
  string method_;
};

// Partial specialization of CropAndResize functor for a CPUDevice.
namespace functor {
template <typename T>
struct CropAndResize<CPUDevice, T> {
  bool operator()(const OpKernelContext* context,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  const string& method_name, float extrapolation_value,
                  typename TTypes<float, 4>::Tensor crops) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    const int num_boxes = crops.dimension(0);
    const int crop_height = crops.dimension(1);
    const int crop_width = crops.dimension(2);
    const int depth = crops.dimension(3);

    // Sharding across boxes.
    auto CropAndResizePerBox = [&](int start_box, int limit_box) {
      for (int b = start_box; b < limit_box; ++b) {
        const float y1 = boxes(b, 0);
        const float x1 = boxes(b, 1);
        const float y2 = boxes(b, 2);
        const float x2 = boxes(b, 3);

        const int32 b_in = box_index(b);
        if (!FastBoundsCheck(b_in, batch_size)) {
          continue;
        }

        const float height_scale =
            (crop_height > 1)
                ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                : 0;
        const float width_scale =
            (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                             : 0;
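        // Box coordinates are normalized: (y1, x1) is the top-left and
        // (y2, x2) the bottom-right corner of the box in [0, 1]. The scale
        // factors map a crop pixel index to a fractional source coordinate,
        // so crop row 0 samples at y1 * (image_height - 1) and crop row
        // (crop_height - 1) samples at y2 * (image_height - 1).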

        for (int y = 0; y < crop_height; ++y) {
          const float in_y = (crop_height > 1)
                                 ? y1 * (image_height - 1) + y * height_scale
                                 : 0.5 * (y1 + y2) * (image_height - 1);
          if (in_y < 0 || in_y > image_height - 1) {
            for (int x = 0; x < crop_width; ++x) {
              for (int d = 0; d < depth; ++d) {
                crops(b, y, x, d) = extrapolation_value;
              }
            }
            continue;
          }
          if (method_name == "bilinear") {
            const int top_y_index = floorf(in_y);
            const int bottom_y_index = ceilf(in_y);
            const float y_lerp = in_y - top_y_index;

            for (int x = 0; x < crop_width; ++x) {
              const float in_x = (crop_width > 1)
                                     ? x1 * (image_width - 1) + x * width_scale
                                     : 0.5 * (x1 + x2) * (image_width - 1);
              if (in_x < 0 || in_x > image_width - 1) {
                for (int d = 0; d < depth; ++d) {
                  crops(b, y, x, d) = extrapolation_value;
                }
                continue;
              }
              const int left_x_index = floorf(in_x);
              const int right_x_index = ceilf(in_x);
              const float x_lerp = in_x - left_x_index;

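              // Bilinearly blend the four neighboring source pixels using the
              // fractional offsets x_lerp and y_lerp.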
              for (int d = 0; d < depth; ++d) {
                const float top_left(static_cast<float>(
                    image(b_in, top_y_index, left_x_index, d)));
                const float top_right(static_cast<float>(
                    image(b_in, top_y_index, right_x_index, d)));
                const float bottom_left(static_cast<float>(
                    image(b_in, bottom_y_index, left_x_index, d)));
                const float bottom_right(static_cast<float>(
                    image(b_in, bottom_y_index, right_x_index, d)));
                const float top = top_left + (top_right - top_left) * x_lerp;
                const float bottom =
                    bottom_left + (bottom_right - bottom_left) * x_lerp;
                crops(b, y, x, d) = top + (bottom - top) * y_lerp;
              }
            }
          } else {  // method_name == "nearest"
            for (int x = 0; x < crop_width; ++x) {
              const float in_x = (crop_width > 1)
                                     ? x1 * (image_width - 1) + x * width_scale
                                     : 0.5 * (x1 + x2) * (image_width - 1);
              if (in_x < 0 || in_x > image_width - 1) {
                for (int d = 0; d < depth; ++d) {
                  crops(b, y, x, d) = extrapolation_value;
                }
                continue;
              }
              const int closest_x_index = roundf(in_x);
              const int closest_y_index = roundf(in_y);
              for (int d = 0; d < depth; ++d) {
                crops(b, y, x, d) = static_cast<float>(
                    image(b_in, closest_y_index, closest_x_index, d));
              }
            }
          }
        }
      }
    };

    // A rough estimation of the cost for each cropped box.
    double cost_per_pixel =
        depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
                 Eigen::TensorOpCost::MulCost<float>() * 3 +
                 Eigen::TensorOpCost::CastCost<T, float>() * 4) +
        (Eigen::TensorOpCost::AddCost<float>() * 2 +
         Eigen::TensorOpCost::AddCost<float>() * 3);
    if (method_name == "nearest") {
      cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
                       Eigen::TensorOpCost::AddCost<float>() * 4 +
                       Eigen::TensorOpCost::MulCost<float>() * 4;
    }
    const double cost_per_box = crop_height * crop_width * cost_per_pixel;

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
          cost_per_box, CropAndResizePerBox);
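    // Shard splits the num_boxes iterations across the intra-op thread pool,
    // using cost_per_box to choose how many boxes to hand to each thread.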

    return true;
  }
};

}  // namespace functor

template <typename Device, typename T>
class CropAndResizeGradImageOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
    OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
                errors::InvalidArgument(
                    "method must be 'bilinear' or 'nearest'", method_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'image_size' is [4].
    const Tensor& image_size = context->input(3);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
                      errors::InvalidArgument("image_size must be 1-D",
                                              image_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
                      errors::InvalidArgument("image_size must have 4 elements",
                                              image_size.shape().DebugString()),
                      done);
    auto image_size_vec = image_size.vec<int32>();
    const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
    const int image_height = internal::SubtleMustCopy(image_size_vec(1));
    const int image_width = internal::SubtleMustCopy(image_size_vec(2));
    const int depth = internal::SubtleMustCopy(image_size_vec(3));
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(3) == depth,
        errors::InvalidArgument("image_size and grads are incompatible"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({batch_size, image_height, image_width, depth}),
            &output),
        done);

    auto compute_callback = [this, context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
          output->tensor<T, 4>(), method_);
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed to launch CropAndResizeBackpropImage kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }

 private:
  string method_;
};

// Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
namespace functor {
template <typename T>
struct CropAndResizeBackpropImage<CPUDevice, T> {
  bool operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<T, 4>::Tensor grads_image,
                  const string& method_name) {
    const int batch_size = grads_image.dimension(0);
    const int image_height = grads_image.dimension(1);
    const int image_width = grads_image.dimension(2);

    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    grads_image.setZero();

    for (int b = 0; b < num_boxes; ++b) {
      const float y1 = boxes(b, 0);
      const float x1 = boxes(b, 1);
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }

      const float height_scale =
          (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                            : 0;
      const float width_scale =
          (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                           : 0;

      for (int y = 0; y < crop_height; ++y) {
        const float in_y = (crop_height > 1)
                               ? y1 * (image_height - 1) + y * height_scale
                               : 0.5 * (y1 + y2) * (image_height - 1);
        if (in_y < 0 || in_y > image_height - 1) {
          continue;
        }
        const int top_y_index = floorf(in_y);
        const int bottom_y_index = ceilf(in_y);
        const float y_lerp = in_y - top_y_index;

        for (int x = 0; x < crop_width; ++x) {
          const float in_x = (crop_width > 1)
                                 ? x1 * (image_width - 1) + x * width_scale
                                 : 0.5 * (x1 + x2) * (image_width - 1);
          if (in_x < 0 || in_x > image_width - 1) {
            continue;
          }
          if (method_name == "bilinear") {
            const int left_x_index = floorf(in_x);
            const int right_x_index = ceilf(in_x);
            const float x_lerp = in_x - left_x_index;

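            // Scatter the incoming gradient to the four source pixels,
            // weighted by the same bilinear coefficients as the forward pass.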
            for (int d = 0; d < depth; ++d) {
              const float dtop = (1 - y_lerp) * grads(b, y, x, d);
              grads_image(b_in, top_y_index, left_x_index, d) +=
                  static_cast<T>((1 - x_lerp) * dtop);
              grads_image(b_in, top_y_index, right_x_index, d) +=
                  static_cast<T>(x_lerp * dtop);
              const float dbottom = y_lerp * grads(b, y, x, d);
              grads_image(b_in, bottom_y_index, left_x_index, d) +=
                  static_cast<T>((1 - x_lerp) * dbottom);
              grads_image(b_in, bottom_y_index, right_x_index, d) +=
                  static_cast<T>(x_lerp * dbottom);
            }
          } else {  // method_name == "nearest"
            for (int d = 0; d < depth; ++d) {
              int closest_x_index = roundf(in_x);
              int closest_y_index = roundf(in_y);
              grads_image(b_in, closest_y_index, closest_x_index, d) +=
                  static_cast<T>(grads(b, y, x, d));
            }
          }
        }
      }
    }
    return true;
  }
};

}  // namespace functor

template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(2);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(3);
    // The shape of 'image' is [batch_size, image_height, image_width, depth].
    const Tensor& image = context->input(1);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    const int depth = grads.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);

    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
                      errors::InvalidArgument("image, grads depth differ"),
                      done);

    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
        done);

    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& image = context->input(1);
      const Tensor& boxes = context->input(2);
      const Tensor& box_index = context->input(3);
      const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), output->tensor<float, 2>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed to launch CropAndResizeBackpropBoxes kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};

// Partial specialization of CropAndResizeBackpropBoxes functor for a CPUDevice.
namespace functor {
template <typename T>
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
  bool operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<float, 2>::Tensor grads_boxes) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    grads_boxes.setZero();

    for (int b = 0; b < num_boxes; ++b) {
      const float y1 = boxes(b, 0);
      const float x1 = boxes(b, 1);
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }

      const float height_ratio =
          (crop_height > 1)
              ? static_cast<float>(image_height - 1) / (crop_height - 1)
              : 0;
      const float width_ratio =
          (crop_width > 1)
              ? static_cast<float>(image_width - 1) / (crop_width - 1)
              : 0;

      const float height_scale =
          (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
      const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;

      for (int y = 0; y < crop_height; ++y) {
        const float in_y = (crop_height > 1)
                               ? y1 * (image_height - 1) + y * height_scale
                               : 0.5 * (y1 + y2) * (image_height - 1);
        if (in_y < 0 || in_y > image_height - 1) {
          continue;
        }
        const int top_y_index = floorf(in_y);
        const int bottom_y_index = ceilf(in_y);
        const float y_lerp = in_y - top_y_index;

        for (int x = 0; x < crop_width; ++x) {
          const float in_x = (crop_width > 1)
                                 ? x1 * (image_width - 1) + x * width_scale
                                 : 0.5 * (x1 + x2) * (image_width - 1);
          if (in_x < 0 || in_x > image_width - 1) {
            continue;
          }
          const int left_x_index = floorf(in_x);
          const int right_x_index = ceilf(in_x);
          const float x_lerp = in_x - left_x_index;

          for (int d = 0; d < depth; ++d) {
            const float top_left(
                static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
            const float top_right(
                static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
            const float bottom_left(static_cast<float>(
                image(b_in, bottom_y_index, left_x_index, d)));
            const float bottom_right(static_cast<float>(
                image(b_in, bottom_y_index, right_x_index, d)));
            // Compute the image gradient.
            float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
                                 x_lerp * (bottom_right - top_right);
            float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
                                 y_lerp * (bottom_right - bottom_left);
            // Modulate the image gradient with the incoming gradient.
            const float top_grad = grads(b, y, x, d);
            image_grad_y *= top_grad;
            image_grad_x *= top_grad;
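            // Chain rule: in_y = y1 * (image_height - 1) + y * height_scale
            // with height_scale = (y2 - y1) * height_ratio, so
            // d(in_y)/d(y1) = (image_height - 1) - y * height_ratio and
            // d(in_y)/d(y2) = y * height_ratio; similarly for x and width.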
            // dy1, dy2
            if (crop_height > 1) {
              grads_boxes(b, 0) +=
                  image_grad_y * (image_height - 1 - y * height_ratio);
              grads_boxes(b, 2) += image_grad_y * (y * height_ratio);
            } else {
              grads_boxes(b, 0) += image_grad_y * 0.5 * (image_height - 1);
              grads_boxes(b, 2) += image_grad_y * 0.5 * (image_height - 1);
            }
            // dx1, dx2
            if (crop_width > 1) {
              grads_boxes(b, 1) +=
                  image_grad_x * (image_width - 1 - x * width_ratio);
              grads_boxes(b, 3) += image_grad_x * (x * width_ratio);
            } else {
              grads_boxes(b, 1) += image_grad_x * 0.5 * (image_width - 1);
              grads_boxes(b, 3) += image_grad_x * 0.5 * (image_width - 1);
            }
          }
        }
      }
    }
    return true;
  }
};

}  // namespace functor

#define REGISTER_KERNEL(T)                                \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("crop_size"),   \
                          CropAndResizeOp<CPUDevice, T>); \
                                                          \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T"),    \
                          CropAndResizeGradBoxesOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                               \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<T>("T")    \
                              .HostMemory("image_size"), \
                          CropAndResizeGradImageOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#if GOOGLE_CUDA

// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
}  // namespace functor

namespace {

// Specialization of RunIfBoxIndexIsValid for a GPUDevice.
template <>
inline void RunIfBoxIndexIsValid<GPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  if (num_boxes == 0) {
    compute();
    done();
    return;
  }

  Tensor isvalid_dev_tensor;
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_dev_tensor),
      done);
  typename TTypes<bool, 0>::Tensor isvalid_dev =
      isvalid_dev_tensor.tensor<bool, 0>();

  // Run the actual box check on the device.
  functor::CheckValidBoxIndexHelper<GPUDevice>()(
      context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);

  // Copy the result back to the host.
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES_ASYNC(context, stream,
                    errors::Internal("No GPU stream available."), done);
  Tensor isvalid_host_tensor;
  // Use pinned host memory to avoid unnecessary synchronization.
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_host_tensor, alloc_attr),
      done);
  se::DeviceMemoryBase wrapped(isvalid_dev.data(), sizeof(bool));
  const bool status =
      stream
          ->ThenMemcpy(
              isvalid_host_tensor.scalar<bool>().data() /* destination */,
              wrapped /* source */, sizeof(bool))
          .ok();
  OP_REQUIRES_ASYNC(
      context, status,
      errors::Internal("Failed to launch copy of isvalid from device to host."),
      done);

  // We capture both temporary tensors to prevent them from being deallocated
  // when ComputeAsync returns and before the closure runs.
  TensorReference isvalid_dev_ref(isvalid_dev_tensor);
  auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref,
                           compute, done]() {
    auto stream = context->op_device_context()->stream();
    ScopedActivateExecutorContext scoped_activation{stream->parent()};
    const bool isvalid = isvalid_host_tensor.scalar<bool>()();
    isvalid_dev_ref.Unref();
    OP_REQUIRES_ASYNC(
        context, isvalid,
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
    compute();
    done();
  };

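  // The EventMgr runs wrapped_callback once all work already enqueued on the
  // stream (including the memcpy above) has completed, so the host-side read
  // of isvalid is safe by then.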
  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, wrapped_callback);
}

}  // namespace

#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")                    \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("crop_size"),            \
                          CropAndResizeOp<GPUDevice, T>);          \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("image_size"),           \
                          CropAndResizeGradImageOp<GPUDevice, T>); \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          CropAndResizeGradBoxesOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow