/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/image/crop_and_resize_op.h"

#include <functional>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;

static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
                                           const Tensor& box_index,
                                           int* num_boxes) {
  if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
    *num_boxes = 0;
    return Status::OK();
  }
  // The shape of 'boxes' is [num_boxes, 4].
  if (boxes.dims() != 2) {
    return errors::InvalidArgument("boxes must be 2-D",
                                   boxes.shape().DebugString());
  }
  *num_boxes = boxes.dim_size(0);
  if (boxes.dim_size(1) != 4) {
    return errors::InvalidArgument("boxes must have 4 columns");
  }
  // The shape of 'box_index' is [num_boxes].
  if (box_index.dims() != 1) {
    return errors::InvalidArgument("box_index must be 1-D",
                                   box_index.shape().DebugString());
  }
  if (box_index.dim_size(0) != *num_boxes) {
    return errors::InvalidArgument("box_index has incompatible shape");
  }
  return Status::OK();
}

// Calls the compute callback if all values in box_index are in
// [0, batch_size), then calls done. Otherwise raises an OutOfRange error
// and calls done without running compute.
template <typename Device>
inline void RunIfBoxIndexIsValid(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done);

// Specialization of RunIfBoxIndexIsValid for a CPUDevice.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  for (int b = 0; b < num_boxes; ++b) {
    OP_REQUIRES_ASYNC(
        context, FastBoundsCheck(box_index(b), batch_size),
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
  }
  if (compute) {
    compute();
  }
  if (done) {
    done();
  }
}

}  // namespace

template <typename Device, typename T>
class CropAndResizeOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
    OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
                errors::InvalidArgument(
                    "method must be 'bilinear' or 'nearest'", method_));
    OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
                                             &extrapolation_value_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'image' is [batch_size, image_height, image_width,
    // channels].
    const Tensor& image = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'crop_size' is [2].
    const Tensor& crop_size = context->input(3);

    // Validate input dimensions.
    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    const int depth = image.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
                      errors::InvalidArgument("crop_size must be 1-D",
                                              crop_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(
        context, crop_size.dim_size(0) == 2,
        errors::InvalidArgument("crop_size must have two elements",
                                crop_size.shape().DebugString()),
        done);

    // Copy and validate crop sizes.
    auto crop_size_vec = crop_size.vec<int32>();
    const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
    const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("crop dimensions must be positive"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({num_boxes, crop_height, crop_width, depth}),
            &output),
        done);

    auto compute_callback = [this, context, output]() {
      const Tensor& image = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResize<Device, T>()(
          context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), method_, extrapolation_value_,
          output->tensor<float, 4>());

      if (!status) {
        context->SetStatus(
            errors::Internal("Failed to launch CropAndResizeKernel."));
      }
    };

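    // The box-index check may run asynchronously (on GPU), so the crop
    // itself is deferred to compute_callback and runs only after the
    // indices have been validated.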
    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }

 private:
  float extrapolation_value_;
  string method_;
};

// Partial specialization of CropAndResize functor for a CPUDevice.
namespace functor {
template <typename T>
struct CropAndResize<CPUDevice, T> {
  bool operator()(OpKernelContext* context,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  const string& method_name, float extrapolation_value,
                  typename TTypes<float, 4>::Tensor crops) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    const int num_boxes = crops.dimension(0);
    const int crop_height = crops.dimension(1);
    const int crop_width = crops.dimension(2);
    const int depth = crops.dimension(3);

    // Since `functor::CropAndResize` operates on float, we first validate
    // that the boxes contain only finite coordinates: non-finite values
    // cause undefined behavior when converted to integer pixel indices,
    // which could result in a segfault.
    const Eigen::Tensor<bool, 0, Eigen::RowMajor> only_finite_elements =
        boxes.isfinite().all();
    if (!only_finite_elements()) {
      context->SetStatus(errors::InvalidArgument(
          "Boxes contains at least one element that is not finite"));
      return false;
    }

    // Sharding across boxes.
    auto CropAndResizePerBox = [&](int64 start_box, int64 limit_box) {
      for (int b = start_box; b < limit_box; ++b) {
        const float y1 = boxes(b, 0);
        const float x1 = boxes(b, 1);
        const float y2 = boxes(b, 2);
        const float x2 = boxes(b, 3);

        const int32 b_in = box_index(b);
        if (!FastBoundsCheck(b_in, batch_size)) {
          continue;
        }

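        // Map crop pixels back into the image. Box coordinates are
        // normalized to [0, 1], so crop row y samples image row
        //   in_y = y1 * (image_height - 1) + y * height_scale.
        // E.g., the full-image box [0, 0, 1, 1] with
        // crop_height == image_height > 1 gives height_scale == 1 and
        // in_y == y, i.e. an identity mapping.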
        const float height_scale =
            (crop_height > 1)
                ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                : 0;
        const float width_scale =
            (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                             : 0;

        for (int y = 0; y < crop_height; ++y) {
          const float in_y = (crop_height > 1)
                                 ? y1 * (image_height - 1) + y * height_scale
                                 : 0.5 * (y1 + y2) * (image_height - 1);
          if (in_y < 0 || in_y > image_height - 1) {
            for (int x = 0; x < crop_width; ++x) {
              for (int d = 0; d < depth; ++d) {
                crops(b, y, x, d) = extrapolation_value;
              }
            }
            continue;
          }
          if (method_name == "bilinear") {
            const int top_y_index = floorf(in_y);
            const int bottom_y_index = ceilf(in_y);
            const float y_lerp = in_y - top_y_index;

            for (int x = 0; x < crop_width; ++x) {
              const float in_x = (crop_width > 1)
                                     ? x1 * (image_width - 1) + x * width_scale
                                     : 0.5 * (x1 + x2) * (image_width - 1);
              if (in_x < 0 || in_x > image_width - 1) {
                for (int d = 0; d < depth; ++d) {
                  crops(b, y, x, d) = extrapolation_value;
                }
                continue;
              }
              const int left_x_index = floorf(in_x);
              const int right_x_index = ceilf(in_x);
              const float x_lerp = in_x - left_x_index;

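              // Bilinear interpolation: lerp the two x-neighbors on the top
              // and bottom rows, then lerp those results in y:
              //   result = top + (bottom - top) * y_lerp.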
              for (int d = 0; d < depth; ++d) {
                const float top_left(static_cast<float>(
                    image(b_in, top_y_index, left_x_index, d)));
                const float top_right(static_cast<float>(
                    image(b_in, top_y_index, right_x_index, d)));
                const float bottom_left(static_cast<float>(
                    image(b_in, bottom_y_index, left_x_index, d)));
                const float bottom_right(static_cast<float>(
                    image(b_in, bottom_y_index, right_x_index, d)));
                const float top = top_left + (top_right - top_left) * x_lerp;
                const float bottom =
                    bottom_left + (bottom_right - bottom_left) * x_lerp;
                crops(b, y, x, d) = top + (bottom - top) * y_lerp;
              }
            }
          } else {  // method_name == "nearest"
            for (int x = 0; x < crop_width; ++x) {
              const float in_x = (crop_width > 1)
                                     ? x1 * (image_width - 1) + x * width_scale
                                     : 0.5 * (x1 + x2) * (image_width - 1);
              if (in_x < 0 || in_x > image_width - 1) {
                for (int d = 0; d < depth; ++d) {
                  crops(b, y, x, d) = extrapolation_value;
                }
                continue;
              }
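              // Nearest-neighbor: copy the pixel whose center is closest
              // to the sampling point (in_y, in_x).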
              const int closest_x_index = roundf(in_x);
              const int closest_y_index = roundf(in_y);
              for (int d = 0; d < depth; ++d) {
                crops(b, y, x, d) = static_cast<float>(
                    image(b_in, closest_y_index, closest_x_index, d));
              }
            }
          }
        }
      }
    };

    // A rough estimate of the cost for each cropped box.
    double cost_per_pixel =
        depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
                 Eigen::TensorOpCost::MulCost<float>() * 3 +
                 Eigen::TensorOpCost::CastCost<T, float>() * 4) +
        (Eigen::TensorOpCost::AddCost<float>() * 2 +
         Eigen::TensorOpCost::AddCost<float>() * 3);
    if (method_name == "nearest") {
      cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
                       Eigen::TensorOpCost::AddCost<float>() * 4 +
                       Eigen::TensorOpCost::MulCost<float>() * 4;
    }
    const double cost_per_box = crop_height * crop_width * cost_per_pixel;

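    // Shard the per-box work across the intra-op thread pool; cost_per_box
    // tells the sharder how many boxes to batch per thread.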
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
          cost_per_box, CropAndResizePerBox);

    return true;
  }
};

}  // namespace functor

template <typename Device, typename T>
class CropAndResizeGradImageOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
    OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
                errors::InvalidArgument(
                    "method must be 'bilinear' or 'nearest'", method_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'image_size' is [4].
    const Tensor& image_size = context->input(3);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
                      errors::InvalidArgument("image_size must be 1-D",
                                              image_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
                      errors::InvalidArgument("image_size must have 4 elements",
                                              image_size.shape().DebugString()),
                      done);
    auto image_size_vec = image_size.vec<int32>();
    const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
    const int image_height = internal::SubtleMustCopy(image_size_vec(1));
    const int image_width = internal::SubtleMustCopy(image_size_vec(2));
    const int depth = internal::SubtleMustCopy(image_size_vec(3));
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(3) == depth,
        errors::InvalidArgument("image_size and grads are incompatible"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({batch_size, image_height, image_width, depth}),
            &output),
        done);

    auto compute_callback = [this, context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
          context, grads.tensor<float, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), output->tensor<T, 4>(), method_);

      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed to launch CropAndResizeBackpropImage kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }

 private:
  string method_;
};

// Partial specialization of CropAndResizeBackpropImage functor for a
// CPUDevice.
namespace functor {
template <typename T>
struct CropAndResizeBackpropImage<CPUDevice, T> {
  bool operator()(const OpKernelContext* context,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<T, 4>::Tensor grads_image,
                  const string& method_name) {
    const int batch_size = grads_image.dimension(0);
    const int image_height = grads_image.dimension(1);
    const int image_width = grads_image.dimension(2);

    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    grads_image.setZero();

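    // Gradients are accumulated with += below (boxes may overlap in the
    // image), so the output must start from zero.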
    auto CropAndResizeBackImgPerBox = [&](int64 start_box, int64 limit_box) {
      for (int b = start_box; b < limit_box; ++b) {
        const float y1 = boxes(b, 0);
        const float x1 = boxes(b, 1);
        const float y2 = boxes(b, 2);
        const float x2 = boxes(b, 3);

        const int32 b_in = box_index(b);
        if (!FastBoundsCheck(b_in, batch_size)) {
          continue;
        }

        const float height_scale =
            (crop_height > 1)
                ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                : 0;
        const float width_scale =
            (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                             : 0;

        for (int y = 0; y < crop_height; ++y) {
          const float in_y = (crop_height > 1)
                                 ? y1 * (image_height - 1) + y * height_scale
                                 : 0.5 * (y1 + y2) * (image_height - 1);
          if (in_y < 0 || in_y > image_height - 1) {
            continue;
          }
          const int top_y_index = floorf(in_y);
          const int bottom_y_index = ceilf(in_y);
          const float y_lerp = in_y - top_y_index;

          for (int x = 0; x < crop_width; ++x) {
            const float in_x = (crop_width > 1)
                                   ? x1 * (image_width - 1) + x * width_scale
                                   : 0.5 * (x1 + x2) * (image_width - 1);
            if (in_x < 0 || in_x > image_width - 1) {
              continue;
            }

            if (method_name == "bilinear") {
              const int left_x_index = floorf(in_x);
              const int right_x_index = ceilf(in_x);
              const float x_lerp = in_x - left_x_index;

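              // The forward pass computed
              //   top + (bottom - top) * y_lerp
              // with top/bottom themselves lerped in x, so each of the four
              // neighbors receives the incoming gradient scaled by its
              // bilinear weight.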
              for (int d = 0; d < depth; ++d) {
                const float dtop = (1 - y_lerp) * grads(b, y, x, d);
                grads_image(b_in, top_y_index, left_x_index, d) +=
                    static_cast<T>((1 - x_lerp) * dtop);
                grads_image(b_in, top_y_index, right_x_index, d) +=
                    static_cast<T>(x_lerp * dtop);
                const float dbottom = y_lerp * grads(b, y, x, d);
                grads_image(b_in, bottom_y_index, left_x_index, d) +=
                    static_cast<T>((1 - x_lerp) * dbottom);
                grads_image(b_in, bottom_y_index, right_x_index, d) +=
                    static_cast<T>(x_lerp * dbottom);
              }
            } else {  // method_name == "nearest"
              const int closest_x_index = roundf(in_x);
              const int closest_y_index = roundf(in_y);
              for (int d = 0; d < depth; ++d) {
                grads_image(b_in, closest_y_index, closest_x_index, d) +=
                    static_cast<T>(grads(b, y, x, d));
              }
            }
          }
        }
      }
    };

    // A rough estimate of the cost for each cropped box, covering the
    // work done in both the depth loop and the pixel loop.
    const double cost_per_pixel =
        (method_name == "bilinear"
             ? depth * (Eigen::TensorOpCost::AddCost<float>() * 7 +
                        Eigen::TensorOpCost::MulCost<float>() * 6 +
                        Eigen::TensorOpCost::CastCost<T, float>() * 4) +
                   Eigen::TensorOpCost::AddCost<float>() * 4
             : depth * (Eigen::TensorOpCost::AddCost<float>() +
                        Eigen::TensorOpCost::CastCost<T, float>()) +
                   Eigen::TensorOpCost::AddCost<float>() * 3);

    const double cost_per_box = crop_height * crop_width * cost_per_pixel;

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
          cost_per_box, CropAndResizeBackImgPerBox);

    return true;
  }
};

}  // namespace functor

template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(2);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(3);
    // The shape of 'image' is [batch_size, image_height, image_width, depth].
    const Tensor& image = context->input(1);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    const int depth = grads.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);

    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
                      errors::InvalidArgument("image, grads depth differ"),
                      done);

    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
        done);

    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& image = context->input(1);
      const Tensor& boxes = context->input(2);
      const Tensor& box_index = context->input(3);
      const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), output->tensor<float, 2>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed to launch CropAndResizeBackpropBoxes kernel."));
      }
    };

    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};

// Partial specialization of CropAndResizeBackpropBoxes functor for a
// CPUDevice.
namespace functor {
template <typename T>
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
  bool operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_index,
                  typename TTypes<float, 2>::Tensor grads_boxes) {
    const int batch_size = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    grads_boxes.setZero();

    for (int b = 0; b < num_boxes; ++b) {
      const float y1 = boxes(b, 0);
      const float x1 = boxes(b, 1);
      const float y2 = boxes(b, 2);
      const float x2 = boxes(b, 3);

      const int32 b_in = box_index(b);
      if (!FastBoundsCheck(b_in, batch_size)) {
        continue;
      }

      const float height_ratio =
          (crop_height > 1)
              ? static_cast<float>(image_height - 1) / (crop_height - 1)
              : 0;
      const float width_ratio =
          (crop_width > 1)
              ? static_cast<float>(image_width - 1) / (crop_width - 1)
              : 0;

      const float height_scale =
          (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
      const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;

      for (int y = 0; y < crop_height; ++y) {
        const float in_y = (crop_height > 1)
                               ? y1 * (image_height - 1) + y * height_scale
                               : 0.5 * (y1 + y2) * (image_height - 1);
        if (in_y < 0 || in_y > image_height - 1) {
          continue;
        }
        const int top_y_index = floorf(in_y);
        const int bottom_y_index = ceilf(in_y);
        const float y_lerp = in_y - top_y_index;

        for (int x = 0; x < crop_width; ++x) {
          const float in_x = (crop_width > 1)
                                 ? x1 * (image_width - 1) + x * width_scale
                                 : 0.5 * (x1 + x2) * (image_width - 1);
          if (in_x < 0 || in_x > image_width - 1) {
            continue;
          }
          const int left_x_index = floorf(in_x);
          const int right_x_index = ceilf(in_x);
          const float x_lerp = in_x - left_x_index;

          for (int d = 0; d < depth; ++d) {
            const float top_left(
                static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
            const float top_right(
                static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
            const float bottom_left(static_cast<float>(
                image(b_in, bottom_y_index, left_x_index, d)));
            const float bottom_right(static_cast<float>(
                image(b_in, bottom_y_index, right_x_index, d)));
            // Compute the image gradient.
            float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
                                 x_lerp * (bottom_right - top_right);
            float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
                                 y_lerp * (bottom_right - bottom_left);
            // Modulate the image gradient with the incoming gradient.
            const float top_grad = grads(b, y, x, d);
            image_grad_y *= top_grad;
            image_grad_x *= top_grad;
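            // Since in_y = y1 * (image_height - 1) + y * height_scale and
            // height_scale = (y2 - y1) * height_ratio:
            //   d(in_y)/dy1 = (image_height - 1) - y * height_ratio
            //   d(in_y)/dy2 = y * height_ratio
            // (and analogously for x), which gives the factors below.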
            // dy1, dy2
            if (crop_height > 1) {
              grads_boxes(b, 0) +=
                  image_grad_y * (image_height - 1 - y * height_ratio);
              grads_boxes(b, 2) += image_grad_y * (y * height_ratio);
            } else {
              grads_boxes(b, 0) += image_grad_y * 0.5 * (image_height - 1);
              grads_boxes(b, 2) += image_grad_y * 0.5 * (image_height - 1);
            }
            // dx1, dx2
            if (crop_width > 1) {
              grads_boxes(b, 1) +=
                  image_grad_x * (image_width - 1 - x * width_ratio);
              grads_boxes(b, 3) += image_grad_x * (x * width_ratio);
            } else {
              grads_boxes(b, 1) += image_grad_x * 0.5 * (image_width - 1);
              grads_boxes(b, 3) += image_grad_x * 0.5 * (image_width - 1);
            }
          }
        }
      }
    }
    return true;
  }
};

}  // namespace functor

#define REGISTER_KERNEL(T)                                \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("crop_size"),   \
                          CropAndResizeOp<CPUDevice, T>); \
                                                          \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T"),    \
                          CropAndResizeGradBoxesOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                               \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<T>("T")    \
                              .HostMemory("image_size"), \
                          CropAndResizeGradImageOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
}  // namespace functor

namespace {

// Specialization of RunIfBoxIndexIsValid for a GPUDevice.
template <>
inline void RunIfBoxIndexIsValid<GPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  if (num_boxes == 0) {
    compute();
    done();
    return;
  }

  Tensor isvalid_dev_tensor;
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_dev_tensor),
      done);
  typename TTypes<bool, 0>::Tensor isvalid_dev =
      isvalid_dev_tensor.tensor<bool, 0>();

  // Run the actual box check on the device.
  functor::CheckValidBoxIndexHelper<GPUDevice>()(
      context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);

  // Copy the result back to the host.
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES_ASYNC(context, stream,
                    errors::Internal("No GPU stream available."), done);
  Tensor isvalid_host_tensor;
  // Use pinned host memory to avoid unnecessary synchronization.
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_host_tensor, alloc_attr),
      done);
  se::DeviceMemoryBase wrapped(isvalid_dev.data(), sizeof(bool));
  const bool status =
      stream
          ->ThenMemcpy(
              isvalid_host_tensor.scalar<bool>().data() /* destination */,
              wrapped /* source */, sizeof(bool))
          .ok();
  OP_REQUIRES_ASYNC(
      context, status,
      errors::Internal("Failed to launch copy of isvalid from device to host."),
      done);

  // We capture both temporary tensors to prevent them from being deallocated
  // when ComputeAsync returns and before the closure runs.
  TensorReference isvalid_dev_ref(isvalid_dev_tensor);
  auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref,
                           compute, done]() {
    auto stream = context->op_device_context()->stream();
    ScopedActivateExecutorContext scoped_activation{stream->parent()};
    const bool isvalid = isvalid_host_tensor.scalar<bool>()();
    isvalid_dev_ref.Unref();
    OP_REQUIRES_ASYNC(
        context, isvalid,
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
    compute();
    done();
  };

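  // Defer the callback through the EventMgr so it runs only after all work
  // already enqueued on the stream (including the memcpy above) completes.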
  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, wrapped_callback);
}

}  // namespace

#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")                    \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("crop_size"),            \
                          CropAndResizeOp<GPUDevice, T>);          \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("image_size"),           \
                          CropAndResizeGradImageOp<GPUDevice, T>); \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          CropAndResizeGradBoxesOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow