1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <numeric>
17 #include <vector>
18 
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/type_util.h"
21 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/array4d.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/framework/kernel_def_builder.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
38 #include "tensorflow/core/lib/math/math_util.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 using xla::XlaOp;
45 
46 // Calculates the bilinear weight tensor, given basis ratio (px, py) of the
47 // sampling position:
48 //    W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
49 // 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2].
50 //
51 // The returned tensor has dimensions [batch, dim_0, ... dim_n, 4].
BilinearWeights(XlaOpKernelContext * ctx,XlaOp ratio,const TensorShape warp_shape,xla::PrimitiveType xla_type)52 XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio,
53                       const TensorShape warp_shape,
54                       xla::PrimitiveType xla_type) {
55   auto first_term = xla::ConstantR2<float>(
56       ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}});
57   first_term = xla::ConvertElementType(first_term, xla_type);
58 
59   auto warp_dims = warp_shape.dim_sizes();
60   std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1);
61   broadcast_dims.push_back(4);
62   broadcast_dims.push_back(2);
63 
64   const int64 broadcast_dims_size = broadcast_dims.size();
65 
66   std::vector<int64> last_two_dims_indices = {(broadcast_dims_size - 2),
67                                               (broadcast_dims_size - 1)};
68 
69   auto broadcast_first_term =
70       xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices);
71 
72   // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n,
73   // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the
74   // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last
75   // dimension.
76   std::vector<int64> ratio_broadcast_indices(broadcast_dims.size());
77   std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0);
78   ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2);
79 
80   auto broadcast_ratio =
81       xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices);
82 
83   auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio;
84 
85   // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to
86   // flip the signs of the second and the third term.
87   auto sign_change = xla::ConstantR2<float>(
88       ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}});
89   sign_change = xla::ConvertElementType(sign_change, xla_type);
90 
91   auto broadcast_sign_change =
92       xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices);
93 
94   auto flipped = first_term_subtract_weights * broadcast_sign_change;
95 
96   // Build up the final bilinear weight tensor by multiply reduction, which
97   // gives:
98   //    [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
99   // for each 4 neighboring pixels where px and py are the weight of the target
100   // pixel we are sampling from.
101   return xla::Reduce(
102       flipped, xla::One(ctx->builder(), xla_type),
103       xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()),
104       {broadcast_dims_size - 1});
105 }
106 
107 // Concatenates the batch indices to the (x, y) coordinate indices.
108 // This is done by first creating an Iota tensor that represents the current
109 // batch it is in, then concatenate with the givin (coordinate) indices.
110 //
111 // The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where
112 // the last dimension of size 3 in turn is [batch_number, x, y].
113 // The [batch_number, x, y] dimension is needed because the indices
114 // [x,y] alone cannot allow the xla::Gather operation to gather from the input
115 // data, which is of dimension [batch, height(y), width(x), channel] with
116 // 'batch' being the first dimension.
ConcatenateIota(xla::XlaBuilder * b,XlaOp indices,const TensorShape & warp_shape)117 XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
118                       const TensorShape& warp_shape) {
119   // We need to create an iota tensor with the same batch dimension.
120   std::vector<int64> dimensions;
121   for (auto dim : warp_shape) {
122     dimensions.push_back(dim.size);
123   }
124   // Except the last dimension, which is of size 1.
125   dimensions.back() = 1;
126 
127   auto batch_indices =
128       xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions),
129                 /*iota_dimension=*/0);
130 
131   return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
132 }
133 
134 // Gathers the 2x2 neighbors of the input starting_indices, and return a
135 // tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels].
136 // 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last
137 // dimension of size 3 is (batch_no, x, y).
Gather2by2Neighbors(xla::XlaBuilder * b,XlaOp data,XlaOp gather_indices,int64 data_channels,int warp_dims)138 XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices,
139                           int64 data_channels, int warp_dims) {
140   xla::GatherDimensionNumbers gather_dim_numbers;
141   const int64 neighbor_data_dimensions = warp_dims + 2;
142   // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2,
143   // data_channels], the offset dimensions for Gather is the last 3 dimensions.
144   gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3);
145   gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2);
146   gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1);
147   // The last dimension of 'gather_indices' is the starting indices for gather.
148   gather_dim_numbers.set_index_vector_dim(warp_dims - 1);
149   gather_dim_numbers.add_collapsed_slice_dims(0);
150   gather_dim_numbers.add_start_index_map(0);
151   // Since input is of dimension [batch, height(y), width(x), channel], and warp
152   // is of dimension [batch, x, y], the ordering of x, y here needs to be
153   // swapped when gathering.
154   gather_dim_numbers.add_start_index_map(2);
155   gather_dim_numbers.add_start_index_map(1);
156   // Data dimensions are [batch, x, y, channel].
157   // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels].
158   auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers,
159                                     /*slice_sizes=*/{1, 2, 2, data_channels});
160   // Collapse the ...,2,2,... dimensions into ...,4,...
161   return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims});
162 }
163 
164 // Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the
165 // resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels].
166 // This function can also be seen as the inverse of 'Gather2by2Neighbors'.
ScatterToGradData(XlaOpKernelContext * ctx,XlaOp grad_data,XlaOp indices,XlaOp updates,int64 warp_dims,xla::PrimitiveType xla_type)167 XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
168                         XlaOp updates, int64 warp_dims,
169                         xla::PrimitiveType xla_type) {
170   xla::ScatterDimensionNumbers scatter_dim_numbers;
171   const int64 neighbor_data_dimensions = warp_dims + 2;
172   // Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2,
173   // data_channels], the update window dimensions is the last 3 dimensions.
174   scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3);
175   scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2);
176   scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1);
177   scatter_dim_numbers.set_index_vector_dim(warp_dims - 1);
178 
179   scatter_dim_numbers.add_inserted_window_dims(0);
180   scatter_dim_numbers.add_scatter_dims_to_operand_dims(0);
181   // Since input is of dimension [batch, height(y), width(x), channel], and warp
182   // is of dimension [batch, x, y], the ordering of x, y here needs to be
183   // swapped when scattering.
184   scatter_dim_numbers.add_scatter_dims_to_operand_dims(2);
185   scatter_dim_numbers.add_scatter_dims_to_operand_dims(1);
186 
187   return xla::Scatter(grad_data, indices, updates,
188                       xla::CreateScalarAddComputation(xla_type, ctx->builder()),
189                       scatter_dim_numbers);
190 }
191 
192 // Bounds samples to 0 if the warp image indices are out of the (-1, image_size)
193 // bound.
194 // The resulting dimension is given by 'result_dims'.
BoundSamples(XlaOpKernelContext * ctx,XlaOp warp,xla::PrimitiveType warp_type,TensorShape warp_shape,std::vector<int64> result_dims,std::vector<int64> broadcasted_dims,int64 last_warp_dim,xla::Shape data_shape,XlaOp sample)195 XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp,
196                    xla::PrimitiveType warp_type, TensorShape warp_shape,
197                    std::vector<int64> result_dims,
198                    std::vector<int64> broadcasted_dims, int64 last_warp_dim,
199                    xla::Shape data_shape, XlaOp sample) {
200   auto is_gt_minus_one =
201       xla::Gt(warp,
202               xla::ConvertElementType(
203                   xla::ConstantR1<float>(ctx->builder(), {-1, -1}), warp_type),
204               /*broadcast_dimensions=*/{warp_shape.dims() - 1});
205   auto is_lt_image_size = xla::Lt(
206       warp,
207       xla::ConvertElementType(
208           xla::ConstantR1<float>(
209               ctx->builder(),
210               {/*width=*/static_cast<float>(data_shape.dimensions(2)),
211                /*height=*/static_cast<float>(data_shape.dimensions(1))}),
212           warp_type),
213       /*broadcast_dimensions=*/{warp_shape.dims() - 1});
214 
215   auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size);
216   // Reduce along last dimension. The resulting dimension is:
217   // [batch, dim_0, ...dim_n].
218   auto is_in_bound = xla::Reduce(
219       is_in_bound_padded_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
220       xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()),
221       {last_warp_dim});
222 
223   // Broadcast 'is_in_bound' to the same dimension as 'result_dims'.
224   auto broadcasted_is_in_bound =
225       xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
226 
227   // Set out of bound samples to zero.
228   auto zeros =
229       xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims);
230   return xla::Select(broadcasted_is_in_bound, sample, zeros);
231 }
232 
233 // Build computation the backprop into input 'data'.
234 // Where input:
235 // grad_output is of dimension [batch, dim_0, ...dim_n, channel]
236 // ratio is of dimension [batch, dim_0, ...dim_n, 2]
237 // gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
238 // data_shape is of dimension [batch, x(width), y(height), channel]
239 //
240 // Output:
241 // scatter-add to each 2x2 grad_data neighbor:
242 //  grad_data[fx, fy, chan] += output_grad * dx * dy
243 //  grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
244 //  grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
245 //  grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
246 // where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their
247 // contribution is 0 to 'grad_data'.
CalculateGradData(XlaOpKernelContext * ctx,XlaOp grad_output,XlaOp ratio,XlaOp gather_indices,XlaOp warp,xla::PrimitiveType warp_type,TensorShape warp_shape,int64 last_warp_dim,int64 data_channels,xla::Shape data_shape)248 XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
249                         XlaOp gather_indices, XlaOp warp,
250                         xla::PrimitiveType warp_type, TensorShape warp_shape,
251                         int64 last_warp_dim, int64 data_channels,
252                         xla::Shape data_shape) {
253   // Weights tensor has dimension [batch, dim_0, ... dim_n, 4].
254   auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
255 
256   auto warp_dims = warp_shape.dim_sizes();
257   std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
258                                                  warp_dims.end() - 1);
259 
260   std::vector<int64> reshaped_weights_dims = warp_dims_without_last_dims;
261   // Reshape the last dimension of size 4 to two dimensions [2, 2].
262   reshaped_weights_dims.push_back(2);
263   reshaped_weights_dims.push_back(2);
264   std::vector<int64> reshape_dims(warp_shape.dims());
265   std::iota(reshape_dims.begin(), reshape_dims.end(), 0);
266   // The dimension is [batch, dim_0,..., dim_n, 2, 2].
267   auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims,
268                                        /*new_sizes=*/reshaped_weights_dims);
269 
270   std::vector<int64> weights_with_channels_dims = reshaped_weights_dims;
271   weights_with_channels_dims.push_back(data_channels);
272   std::vector<int64> reshaped_weights_indices(reshaped_weights_dims.size());
273   std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(),
274             0);
275 
276   // Set out of bound weights to 0.
277   // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2].
278   std::vector<int64> reshaped_result_dims(warp_dims.begin(),
279                                           warp_dims.end() - 1);
280   reshaped_result_dims.push_back(2);
281   reshaped_result_dims.push_back(2);
282   std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
283   std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
284   reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape,
285                                   reshaped_result_dims, broadcasted_dims,
286                                   last_warp_dim, data_shape, reshaped_weights);
287 
288   // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel].
289   auto broadcast_reshaped_weights = xla::BroadcastInDim(
290       reshaped_weights, weights_with_channels_dims, reshaped_weights_indices);
291 
292   std::vector<int64> grad_output_indices(warp_dims_without_last_dims.size());
293   std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0);
294   grad_output_indices.push_back(weights_with_channels_dims.size() - 1);
295   XlaOp broadcast_grad_output = xla::BroadcastInDim(
296       grad_output, weights_with_channels_dims, grad_output_indices);
297 
298   auto grad_output_multiply_weights =
299       broadcast_grad_output * broadcast_reshaped_weights;
300 
301   auto grad_data = xla::ConstantLiteral(
302       ctx->builder(), xla::Literal::CreateFromShape(data_shape));
303 
304   // Pad grad data then slice it back.
305   //
306   // After left and right column 0-padding, the new dimension of padded data
307   // will be [batch, x+2, y+2, channel].
308   auto padded_grad_data =
309       xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type),
310                xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
311 
312   auto shifting_value = xla::ConstantR1<int32>(
313       ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
314   auto shifted_gather_indices =
315       xla::Add(gather_indices, shifting_value, {last_warp_dim});
316 
317   auto updated_grad_data = ScatterToGradData(
318       ctx, padded_grad_data, shifted_gather_indices,
319       grad_output_multiply_weights, warp_shape.dims(), warp_type);
320 
321   const int64 batch_size = data_shape.dimensions(0);
322   const int64 width = data_shape.dimensions(1);
323   const int64 height = data_shape.dimensions(2);
324   // Slice out the result accounting for the padding.
325   return xla::Slice(
326       updated_grad_data, /*start_indices=*/{0, 1, 1, 0},
327       /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels},
328       /*strides=*/{1, 1, 1, 1});
329 }
330 
331 // Build computation for the backprop into input 'warp'.
332 // Where input:
333 //  warp is of dimension [batch, dim_0, ...dim_n, 2]
334 //  grad_output is of dimension [batch, dim_0, ...dim_n, channel]
335 //  ratio is of dimension [batch, dim_0, ...dim_n, 2]
336 //  gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last
337 //  dimension of size 3 is for {batch, x(width), y(height)}.
338 //  data is of dimension [batch, x, y, channel]
339 //
340 // Output (simplified by ignoring the batch dimensions):
341 // Since the forward path has:
342 //    output = dot(weights * neighbors)
343 // The backprop into warp will therefore be:
344 //    grad_warp = output_grad * d_output / d_warp
345 //              = output_grad * (d_weights / d_warp * neighbors + d_neighbors /
346 //              d_warp * weight)
347 // Where:
348 //    d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py]
349 //    d_weights / d_warp_y = [-(1 - px), -px, (1-px), px]
350 // and
351 //    d_neighbors / d_warp_x = 0
352 //
353 // Therefore:
354 //    grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy)
355 //    grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy)
356 //
357 // where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the
358 // bottom right corner in a 2x2 neighborhood.
CalculateGradWarp(XlaOpKernelContext * ctx,XlaOp grad_output,XlaOp ratio,XlaOp gather_indices,XlaOp data,TensorShape warp_shape,int64 data_channels,xla::PrimitiveType data_type,xla::Shape data_shape)359 XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
360                         XlaOp gather_indices, XlaOp data,
361                         TensorShape warp_shape, int64 data_channels,
362                         xla::PrimitiveType data_type, xla::Shape data_shape) {
363   auto warp_dims = warp_shape.dim_sizes();
364   std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
365                                                  warp_dims.end() - 1);
366 
367   // With dimension [batch, dim_0, ...dim_n, 4]
368   std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims;
369   neighbor_broadcast_dims.push_back(4);
370 
371   // With dimension [batch, dim_0, ...dim_n, 4]
372   auto neighbor_broadcast_shape =
373       xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
374 
375   const int64 last_warp_dim = warp_shape.dims() - 1;
376 
377   // Pad data with 0, before gathering such that 0 will be returned for samples
378   // in the range of (-1, 0) or (image_dimension-1, image_dimension).
379   // After left and right column 0-padding, the new dimension of padded data
380   // will be [batch, x+2, y+2, channel].
381   auto padded_data =
382       xla::Pad(data, xla::Zero(ctx->builder(), data_type),
383                xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
384 
385   auto shifting_value = xla::ConstantR1<int32>(
386       ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
387   auto shifted_gather_indices =
388       xla::Add(gather_indices, shifting_value, {last_warp_dim});
389 
390   // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
391   auto neighbors_data =
392       Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices,
393                           data_channels, warp_shape.dims());
394 
395   // Since we will be creating the dot product of:
396   //  lhs: [batch, dim_0, ...dim_n, 4]
397   // and
398   //  rhs: [batch, dim_0, ...dim_n, 4, data_channels]
399   // we choose the last dimension of lhs and the second last dimension of rhs,
400   // with size 4, as the contracting dimension.
401   xla::DotDimensionNumbers dot_dims;
402   for (int i = 0; i < warp_shape.dims() - 1; ++i) {
403     dot_dims.add_lhs_batch_dimensions(i);
404     dot_dims.add_rhs_batch_dimensions(i);
405   }
406   dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
407   dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
408 
409   // img_cxcy - img_fxcy
410   auto bottom_right_minus_bottom_left = xla::DotGeneral(
411       xla::BroadcastInDim(
412           xla::ConvertElementType(
413               xla::ConstantR1<float>(ctx->builder(), {0, 0, -1, 1}), data_type),
414           neighbor_broadcast_dims, {last_warp_dim}),
415       neighbors_data, dot_dims, /*precision_config=*/nullptr);
416 
417   // img_cxfy - img_fxfy
418   auto top_right_minus_top_left = xla::DotGeneral(
419       xla::BroadcastInDim(
420           xla::ConvertElementType(
421               xla::ConstantR1<float>(ctx->builder(), {-1, 1, 0, 0}), data_type),
422           neighbor_broadcast_dims, {last_warp_dim}),
423       neighbors_data, dot_dims, /*precision_config=*/nullptr);
424 
425   // img_cxcy - img_cxfy
426   auto bottom_right_minus_top_right = xla::DotGeneral(
427       xla::BroadcastInDim(
428           xla::ConvertElementType(
429               xla::ConstantR1<float>(ctx->builder(), {0, -1, 0, 1}), data_type),
430           neighbor_broadcast_dims, {last_warp_dim}),
431       neighbors_data, dot_dims, /*precision_config=*/nullptr);
432 
433   // img_fxcy - img_fxfy
434   auto bottom_left_minus_top_left = xla::DotGeneral(
435       xla::BroadcastInDim(
436           xla::ConvertElementType(
437               xla::ConstantR1<float>(ctx->builder(), {-1, 0, 1, 0}), data_type),
438           neighbor_broadcast_dims, {last_warp_dim}),
439       neighbors_data, dot_dims, /*precision_config=*/nullptr);
440 
441   // Slice out x and y.
442   auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1,
443                                   /*stride=*/1, /*dimno=*/last_warp_dim);
444   auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2,
445                                   /*stride=*/1, /*dimno=*/last_warp_dim);
446 
447   // Build 1 - y and 1 - x.
448   auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y;
449   auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x;
450 
451   auto x_before_reduce =
452       grad_output * weight_y * bottom_right_minus_bottom_left +
453       one_minus_y * top_right_minus_top_left;
454 
455   std::vector<int64> reshaped_sizes = warp_dims_without_last_dims;
456   reshaped_sizes.push_back(1);
457 
458   std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size());
459   std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0);
460 
461   // Reduce-add along the channel dimension.
462   auto x_result =
463       xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type),
464                   xla::CreateScalarAddComputation(data_type, ctx->builder()),
465                   {last_warp_dim});
466   // Reshape before concatenating with y values.
467   XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes);
468 
469   auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right +
470                          one_minus_x * bottom_left_minus_top_left;
471   // Reduce-add along the channel dimension.
472   auto y_result =
473       xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type),
474 
475                   xla::CreateScalarAddComputation(data_type, ctx->builder()),
476                   {last_warp_dim});
477   XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes);
478 
479   return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y},
480                           last_warp_dim);
481 }
482 
483 class ResamplerOp : public XlaOpKernel {
484  public:
ResamplerOp(OpKernelConstruction * ctx)485   explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
486 
Compile(XlaOpKernelContext * ctx)487   void Compile(XlaOpKernelContext* ctx) override {
488     TensorShape data_shape = ctx->InputShape("data");
489     OP_REQUIRES(ctx, data_shape.dims() == 4,
490                 errors::InvalidArgument("data must be 4-dimensional",
491                                         data_shape.DebugString()));
492     const int64 data_channels = data_shape.dim_size(3);
493     xla::PrimitiveType data_type = ctx->input_xla_type(0);
494 
495     TensorShape warp_shape = ctx->InputShape("warp");
496     OP_REQUIRES(ctx, warp_shape.dims() >= 2,
497                 errors::InvalidArgument("warp must be at least 2-dimensional",
498                                         warp_shape.DebugString()));
499     for (int size : warp_shape.dim_sizes()) {
500       OP_REQUIRES(ctx, size > 0,
501                   errors::InvalidArgument("warp sizes must be positive, got [",
502                                           size, "]"));
503     }
504     const int64 last_warp_dim = warp_shape.dims() - 1;
505     // Last dimension of warp shape must be of size 2.
506     OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
507                 errors::InvalidArgument(
508                     "the last dimension of warp must be exactly size 2."));
509     xla::PrimitiveType warp_type = ctx->input_xla_type(1);
510 
511     XlaOp data = ctx->Input("data");
512     XlaOp warp = ctx->Input("warp");
513 
514     // Find the coordinates of the top left corner for the 2x2 region to be
515     // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the
516     // last dimension of size 2 in turn is [x, y].
517     XlaOp top_left = xla::ConvertElementType(warp, xla::S32);
518 
519     auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
520 
521     // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
522     auto neighbors_data = Gather2by2Neighbors(
523         ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
524 
525     // Dimensions are [batch, dim_0, ... dim_n, 2].
526     XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type);
527 
528     // Obtain the bilinear blending weights, the dimension is [batch, dim_0,
529     // ...dim_n, 4].
530     auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type);
531 
532     // Since we will be creating the dot product of:
533     //  lhs: [batch, dim_0, ...dim_n, 4]
534     // and
535     //  rhs: [batch, dim_0, ...dim_n, 4, data_channels]
536     // we choose the last dimension of lhs and the second last dimension of rhs,
537     // with size 4, as the contracting dimension.
538     xla::DotDimensionNumbers dot_dims;
539     for (int i = 0; i < warp_shape.dims() - 1; ++i) {
540       dot_dims.add_lhs_batch_dimensions(i);
541       dot_dims.add_rhs_batch_dimensions(i);
542     }
543     dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
544     dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
545 
546     // The dimension is [batch, dim_0, ...dim_n, data_channels].
547     auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
548                                           /*precision_config=*/nullptr);
549 
550     // Handle out of boundary cases by constructing a predicate mask array based
551     // on the in-bound condition, and output 0 for the blended pixel value if
552     // out-bound. The dimension is the same as top_left: [batch, dim_0,
553     // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate.
554 
555     auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp));
556 
557     auto is_lt_image_size = xla::Lt(
558         warp,
559         xla::ConvertElementType(
560             xla::ConstantR1<float>(
561                 ctx->builder(),
562                 {/*width=*/static_cast<float>(data_shape.dim_size(2) - 1),
563                  /*height=*/static_cast<float>(data_shape.dim_size(1) - 1)}),
564             warp_type),
565         /*broadcast_dimensions=*/{warp_shape.dims() - 1});
566 
567     auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size);
568     // Reduce along last dimension. The resulting dimension is:
569     // [batch, dim_0, ...dim_n].
570     auto is_in_bound = xla::Reduce(
571         is_in_bound_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
572         xla::CreateScalarAndComputation(xla::PrimitiveType::PRED,
573                                         ctx->builder()),
574         {last_warp_dim});
575 
576     // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which
577     // is the dimension of the result:
578     //  [batch, dim_0, ...dim_n, data_channels].
579     auto warp_dims = warp_shape.dim_sizes();
580     std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
581     result_dims.push_back(data_channels);
582 
583     std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
584     std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
585     auto broadcasted_is_in_bound =
586         xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
587 
588     // Set out of bound samples to zero.
589     auto zeros =
590         xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims);
591     auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros);
592 
593     ctx->SetOutput(0, result);
594   }
595 };
596 
597 REGISTER_XLA_OP(Name("Resampler"), ResamplerOp);
598 
599 class ResamplerGradOp : public XlaOpKernel {
600  public:
ResamplerGradOp(OpKernelConstruction * ctx)601   explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
602     DataType output_dtype;
603     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
604   }
605 
606   // TODO(b/112295522): note that sampling from image boundary is not currently
607   // being handled properly.
Compile(XlaOpKernelContext * ctx)608   void Compile(XlaOpKernelContext* ctx) override {
609     TensorShape data_shape_tf = ctx->InputShape("data");
610     OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
611                 errors::InvalidArgument("data must be 4-dimensional",
612                                         data_shape_tf.DebugString()));
613     const int64 data_channels = data_shape_tf.dim_size(3);
614     xla::PrimitiveType data_type = ctx->input_xla_type(0);
615 
616     TensorShape warp_shape = ctx->InputShape("warp");
617     OP_REQUIRES(ctx, warp_shape.dims() >= 2,
618                 errors::InvalidArgument("warp must be at least 2-dimensional",
619                                         warp_shape.DebugString()));
620     for (int size : warp_shape.dim_sizes()) {
621       OP_REQUIRES(ctx, size > 0,
622                   errors::InvalidArgument("warp sizes must be positive, got [",
623                                           size, "]"));
624     }
625     // Last dimension of warp shape must be of size 2.
626     const int64 last_warp_dim = warp_shape.dims() - 1;
627     OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
628                 errors::InvalidArgument(
629                     "the last dimension of warp must be exactly size 2."));
630     xla::PrimitiveType warp_type = ctx->input_xla_type(1);
631 
632     TensorShape output_grad_shape = ctx->InputShape("grad_output");
633     OP_REQUIRES(
634         ctx, output_grad_shape.dims() >= 2,
635         errors::InvalidArgument("output_grad must be at least 2-dimensional",
636                                 output_grad_shape.DebugString()));
637 
638     // Dimensions are [batch, x, y, channel].
639     XlaOp data = ctx->Input("data");
640     xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf);
641 
642     // Dimensions are [batch, dim_0, ...dim_n, 2].
643     XlaOp warp = ctx->Input("warp");
644     // Dimensions are [batch, dim_0, ...dim_n, channel].
645     XlaOp grad_output = ctx->Input("grad_output");
646 
647     // Find the top left corner coordinate for the region to be sampled from.
648     // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
649     // of size 2 in turn is [x, y].
650     XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32);
651 
652     // Dimensions are [batch, dim_0, ... dim_n, 2].
653     XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
654 
655     // Indices for gathering neighboring pixels.
656     auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
657 
658     auto grad_data = CalculateGradData(
659         ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape,
660         last_warp_dim, data_channels, data_shape);
661 
662     auto grad_warp =
663         CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
664                           warp_shape, data_channels, data_type, data_shape);
665     auto warp_dims = warp_shape.dim_sizes();
666     std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
667     result_dims.push_back(2);
668     std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
669     std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
670     auto grad_warp_bounded =
671         BoundSamples(ctx, warp, warp_type, warp_shape, result_dims,
672                      broadcasted_dims, last_warp_dim, data_shape, grad_warp);
673 
674     ctx->SetOutput(0, grad_data);
675     ctx->SetOutput(1, grad_warp_bounded);
676   }
677 };
678 
679 REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp);
680 
681 }  // namespace
682 }  // namespace tensorflow
683