1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <numeric>
17 #include <vector>
18
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/type_util.h"
21 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/array4d.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/framework/kernel_def_builder.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
38 #include "tensorflow/core/lib/math/math_util.h"
39 #include "tensorflow/core/platform/types.h"
40
41 namespace tensorflow {
42 namespace {
43
44 using xla::XlaOp;
45
46 // Calculates the bilinear weight tensor, given basis ratio (px, py) of the
47 // sampling position:
48 // W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
49 // 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2].
50 //
51 // The returned tensor has dimensions [batch, dim_0, ... dim_n, 4].
BilinearWeights(XlaOpKernelContext * ctx,XlaOp ratio,const TensorShape warp_shape,xla::PrimitiveType xla_type)52 XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio,
53 const TensorShape warp_shape,
54 xla::PrimitiveType xla_type) {
55 auto first_term = xla::ConstantR2<float>(
56 ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}});
57 first_term = xla::ConvertElementType(first_term, xla_type);
58
59 auto warp_dims = warp_shape.dim_sizes();
60 std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1);
61 broadcast_dims.push_back(4);
62 broadcast_dims.push_back(2);
63
64 const int64 broadcast_dims_size = broadcast_dims.size();
65
66 std::vector<int64> last_two_dims_indices = {(broadcast_dims_size - 2),
67 (broadcast_dims_size - 1)};
68
69 auto broadcast_first_term =
70 xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices);
71
72 // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n,
73 // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the
74 // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last
75 // dimension.
76 std::vector<int64> ratio_broadcast_indices(broadcast_dims.size());
77 std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0);
78 ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2);
79
80 auto broadcast_ratio =
81 xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices);
82
83 auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio;
84
85 // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to
86 // flip the signs of the second and the third term.
87 auto sign_change = xla::ConstantR2<float>(
88 ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}});
89 sign_change = xla::ConvertElementType(sign_change, xla_type);
90
91 auto broadcast_sign_change =
92 xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices);
93
94 auto flipped = first_term_subtract_weights * broadcast_sign_change;
95
96 // Build up the final bilinear weight tensor by multiply reduction, which
97 // gives:
98 // [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
99 // for each 4 neighboring pixels where px and py are the weight of the target
100 // pixel we are sampling from.
101 return xla::Reduce(
102 flipped, xla::One(ctx->builder(), xla_type),
103 xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()),
104 {broadcast_dims_size - 1});
105 }
106
107 // Concatenates the batch indices to the (x, y) coordinate indices.
108 // This is done by first creating an Iota tensor that represents the current
109 // batch it is in, then concatenate with the givin (coordinate) indices.
110 //
111 // The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where
112 // the last dimension of size 3 in turn is [batch_number, x, y].
113 // The [batch_number, x, y] dimension is needed because the indices
114 // [x,y] alone cannot allow the xla::Gather operation to gather from the input
115 // data, which is of dimension [batch, height(y), width(x), channel] with
116 // 'batch' being the first dimension.
ConcatenateIota(xla::XlaBuilder * b,XlaOp indices,const TensorShape & warp_shape)117 XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
118 const TensorShape& warp_shape) {
119 // We need to create an iota tensor with the same batch dimension.
120 std::vector<int64> dimensions;
121 for (auto dim : warp_shape) {
122 dimensions.push_back(dim.size);
123 }
124 // Except the last dimension, which is of size 1.
125 dimensions.back() = 1;
126
127 auto batch_indices =
128 xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions),
129 /*iota_dimension=*/0);
130
131 return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
132 }
133
134 // Gathers the 2x2 neighbors of the input starting_indices, and return a
135 // tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels].
136 // 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last
137 // dimension of size 3 is (batch_no, x, y).
Gather2by2Neighbors(xla::XlaBuilder * b,XlaOp data,XlaOp gather_indices,int64 data_channels,int warp_dims)138 XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices,
139 int64 data_channels, int warp_dims) {
140 xla::GatherDimensionNumbers gather_dim_numbers;
141 const int64 neighbor_data_dimensions = warp_dims + 2;
142 // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2,
143 // data_channels], the offset dimensions for Gather is the last 3 dimensions.
144 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3);
145 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2);
146 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1);
147 // The last dimension of 'gather_indices' is the starting indices for gather.
148 gather_dim_numbers.set_index_vector_dim(warp_dims - 1);
149 gather_dim_numbers.add_collapsed_slice_dims(0);
150 gather_dim_numbers.add_start_index_map(0);
151 // Since input is of dimension [batch, height(y), width(x), channel], and warp
152 // is of dimension [batch, x, y], the ordering of x, y here needs to be
153 // swapped when gathering.
154 gather_dim_numbers.add_start_index_map(2);
155 gather_dim_numbers.add_start_index_map(1);
156 // Data dimensions are [batch, x, y, channel].
157 // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels].
158 auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers,
159 /*slice_sizes=*/{1, 2, 2, data_channels});
160 // Collapse the ...,2,2,... dimensions into ...,4,...
161 return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims});
162 }
163
164 // Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the
165 // resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels].
166 // This function can also be seen as the inverse of 'Gather2by2Neighbors'.
ScatterToGradData(XlaOpKernelContext * ctx,XlaOp grad_data,XlaOp indices,XlaOp updates,int64 warp_dims,xla::PrimitiveType xla_type)167 XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
168 XlaOp updates, int64 warp_dims,
169 xla::PrimitiveType xla_type) {
170 xla::ScatterDimensionNumbers scatter_dim_numbers;
171 const int64 neighbor_data_dimensions = warp_dims + 2;
172 // Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2,
173 // data_channels], the update window dimensions is the last 3 dimensions.
174 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3);
175 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2);
176 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1);
177 scatter_dim_numbers.set_index_vector_dim(warp_dims - 1);
178
179 scatter_dim_numbers.add_inserted_window_dims(0);
180 scatter_dim_numbers.add_scatter_dims_to_operand_dims(0);
181 // Since input is of dimension [batch, height(y), width(x), channel], and warp
182 // is of dimension [batch, x, y], the ordering of x, y here needs to be
183 // swapped when scattering.
184 scatter_dim_numbers.add_scatter_dims_to_operand_dims(2);
185 scatter_dim_numbers.add_scatter_dims_to_operand_dims(1);
186
187 return xla::Scatter(grad_data, indices, updates,
188 xla::CreateScalarAddComputation(xla_type, ctx->builder()),
189 scatter_dim_numbers);
190 }
191
192 // Bounds samples to 0 if the warp image indices are out of the (-1, image_size)
193 // bound.
194 // The resulting dimension is given by 'result_dims'.
BoundSamples(XlaOpKernelContext * ctx,XlaOp warp,xla::PrimitiveType warp_type,TensorShape warp_shape,std::vector<int64> result_dims,std::vector<int64> broadcasted_dims,int64 last_warp_dim,xla::Shape data_shape,XlaOp sample)195 XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp,
196 xla::PrimitiveType warp_type, TensorShape warp_shape,
197 std::vector<int64> result_dims,
198 std::vector<int64> broadcasted_dims, int64 last_warp_dim,
199 xla::Shape data_shape, XlaOp sample) {
200 auto is_gt_minus_one =
201 xla::Gt(warp,
202 xla::ConvertElementType(
203 xla::ConstantR1<float>(ctx->builder(), {-1, -1}), warp_type),
204 /*broadcast_dimensions=*/{warp_shape.dims() - 1});
205 auto is_lt_image_size = xla::Lt(
206 warp,
207 xla::ConvertElementType(
208 xla::ConstantR1<float>(
209 ctx->builder(),
210 {/*width=*/static_cast<float>(data_shape.dimensions(2)),
211 /*height=*/static_cast<float>(data_shape.dimensions(1))}),
212 warp_type),
213 /*broadcast_dimensions=*/{warp_shape.dims() - 1});
214
215 auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size);
216 // Reduce along last dimension. The resulting dimension is:
217 // [batch, dim_0, ...dim_n].
218 auto is_in_bound = xla::Reduce(
219 is_in_bound_padded_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
220 xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()),
221 {last_warp_dim});
222
223 // Broadcast 'is_in_bound' to the same dimension as 'result_dims'.
224 auto broadcasted_is_in_bound =
225 xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
226
227 // Set out of bound samples to zero.
228 auto zeros =
229 xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims);
230 return xla::Select(broadcasted_is_in_bound, sample, zeros);
231 }
232
233 // Build computation the backprop into input 'data'.
234 // Where input:
235 // grad_output is of dimension [batch, dim_0, ...dim_n, channel]
236 // ratio is of dimension [batch, dim_0, ...dim_n, 2]
237 // gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
238 // data_shape is of dimension [batch, x(width), y(height), channel]
239 //
240 // Output:
241 // scatter-add to each 2x2 grad_data neighbor:
242 // grad_data[fx, fy, chan] += output_grad * dx * dy
243 // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
244 // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
245 // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
246 // where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their
247 // contribution is 0 to 'grad_data'.
CalculateGradData(XlaOpKernelContext * ctx,XlaOp grad_output,XlaOp ratio,XlaOp gather_indices,XlaOp warp,xla::PrimitiveType warp_type,TensorShape warp_shape,int64 last_warp_dim,int64 data_channels,xla::Shape data_shape)248 XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
249 XlaOp gather_indices, XlaOp warp,
250 xla::PrimitiveType warp_type, TensorShape warp_shape,
251 int64 last_warp_dim, int64 data_channels,
252 xla::Shape data_shape) {
253 // Weights tensor has dimension [batch, dim_0, ... dim_n, 4].
254 auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
255
256 auto warp_dims = warp_shape.dim_sizes();
257 std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
258 warp_dims.end() - 1);
259
260 std::vector<int64> reshaped_weights_dims = warp_dims_without_last_dims;
261 // Reshape the last dimension of size 4 to two dimensions [2, 2].
262 reshaped_weights_dims.push_back(2);
263 reshaped_weights_dims.push_back(2);
264 std::vector<int64> reshape_dims(warp_shape.dims());
265 std::iota(reshape_dims.begin(), reshape_dims.end(), 0);
266 // The dimension is [batch, dim_0,..., dim_n, 2, 2].
267 auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims,
268 /*new_sizes=*/reshaped_weights_dims);
269
270 std::vector<int64> weights_with_channels_dims = reshaped_weights_dims;
271 weights_with_channels_dims.push_back(data_channels);
272 std::vector<int64> reshaped_weights_indices(reshaped_weights_dims.size());
273 std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(),
274 0);
275
276 // Set out of bound weights to 0.
277 // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2].
278 std::vector<int64> reshaped_result_dims(warp_dims.begin(),
279 warp_dims.end() - 1);
280 reshaped_result_dims.push_back(2);
281 reshaped_result_dims.push_back(2);
282 std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
283 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
284 reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape,
285 reshaped_result_dims, broadcasted_dims,
286 last_warp_dim, data_shape, reshaped_weights);
287
288 // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel].
289 auto broadcast_reshaped_weights = xla::BroadcastInDim(
290 reshaped_weights, weights_with_channels_dims, reshaped_weights_indices);
291
292 std::vector<int64> grad_output_indices(warp_dims_without_last_dims.size());
293 std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0);
294 grad_output_indices.push_back(weights_with_channels_dims.size() - 1);
295 XlaOp broadcast_grad_output = xla::BroadcastInDim(
296 grad_output, weights_with_channels_dims, grad_output_indices);
297
298 auto grad_output_multiply_weights =
299 broadcast_grad_output * broadcast_reshaped_weights;
300
301 auto grad_data = xla::ConstantLiteral(
302 ctx->builder(), xla::Literal::CreateFromShape(data_shape));
303
304 // Pad grad data then slice it back.
305 //
306 // After left and right column 0-padding, the new dimension of padded data
307 // will be [batch, x+2, y+2, channel].
308 auto padded_grad_data =
309 xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type),
310 xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
311
312 auto shifting_value = xla::ConstantR1<int32>(
313 ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
314 auto shifted_gather_indices =
315 xla::Add(gather_indices, shifting_value, {last_warp_dim});
316
317 auto updated_grad_data = ScatterToGradData(
318 ctx, padded_grad_data, shifted_gather_indices,
319 grad_output_multiply_weights, warp_shape.dims(), warp_type);
320
321 const int64 batch_size = data_shape.dimensions(0);
322 const int64 width = data_shape.dimensions(1);
323 const int64 height = data_shape.dimensions(2);
324 // Slice out the result accounting for the padding.
325 return xla::Slice(
326 updated_grad_data, /*start_indices=*/{0, 1, 1, 0},
327 /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels},
328 /*strides=*/{1, 1, 1, 1});
329 }
330
331 // Build computation for the backprop into input 'warp'.
332 // Where input:
333 // warp is of dimension [batch, dim_0, ...dim_n, 2]
334 // grad_output is of dimension [batch, dim_0, ...dim_n, channel]
335 // ratio is of dimension [batch, dim_0, ...dim_n, 2]
336 // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last
337 // dimension of size 3 is for {batch, x(width), y(height)}.
338 // data is of dimension [batch, x, y, channel]
339 //
340 // Output (simplified by ignoring the batch dimensions):
341 // Since the forward path has:
342 // output = dot(weights * neighbors)
343 // The backprop into warp will therefore be:
344 // grad_warp = output_grad * d_output / d_warp
345 // = output_grad * (d_weights / d_warp * neighbors + d_neighbors /
346 // d_warp * weight)
347 // Where:
348 // d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py]
349 // d_weights / d_warp_y = [-(1 - px), -px, (1-px), px]
350 // and
351 // d_neighbors / d_warp_x = 0
352 //
353 // Therefore:
354 // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy)
355 // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy)
356 //
357 // where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the
358 // bottom right corner in a 2x2 neighborhood.
CalculateGradWarp(XlaOpKernelContext * ctx,XlaOp grad_output,XlaOp ratio,XlaOp gather_indices,XlaOp data,TensorShape warp_shape,int64 data_channels,xla::PrimitiveType data_type,xla::Shape data_shape)359 XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
360 XlaOp gather_indices, XlaOp data,
361 TensorShape warp_shape, int64 data_channels,
362 xla::PrimitiveType data_type, xla::Shape data_shape) {
363 auto warp_dims = warp_shape.dim_sizes();
364 std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
365 warp_dims.end() - 1);
366
367 // With dimension [batch, dim_0, ...dim_n, 4]
368 std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims;
369 neighbor_broadcast_dims.push_back(4);
370
371 // With dimension [batch, dim_0, ...dim_n, 4]
372 auto neighbor_broadcast_shape =
373 xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
374
375 const int64 last_warp_dim = warp_shape.dims() - 1;
376
377 // Pad data with 0, before gathering such that 0 will be returned for samples
378 // in the range of (-1, 0) or (image_dimension-1, image_dimension).
379 // After left and right column 0-padding, the new dimension of padded data
380 // will be [batch, x+2, y+2, channel].
381 auto padded_data =
382 xla::Pad(data, xla::Zero(ctx->builder(), data_type),
383 xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
384
385 auto shifting_value = xla::ConstantR1<int32>(
386 ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
387 auto shifted_gather_indices =
388 xla::Add(gather_indices, shifting_value, {last_warp_dim});
389
390 // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
391 auto neighbors_data =
392 Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices,
393 data_channels, warp_shape.dims());
394
395 // Since we will be creating the dot product of:
396 // lhs: [batch, dim_0, ...dim_n, 4]
397 // and
398 // rhs: [batch, dim_0, ...dim_n, 4, data_channels]
399 // we choose the last dimension of lhs and the second last dimension of rhs,
400 // with size 4, as the contracting dimension.
401 xla::DotDimensionNumbers dot_dims;
402 for (int i = 0; i < warp_shape.dims() - 1; ++i) {
403 dot_dims.add_lhs_batch_dimensions(i);
404 dot_dims.add_rhs_batch_dimensions(i);
405 }
406 dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
407 dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
408
409 // img_cxcy - img_fxcy
410 auto bottom_right_minus_bottom_left = xla::DotGeneral(
411 xla::BroadcastInDim(
412 xla::ConvertElementType(
413 xla::ConstantR1<float>(ctx->builder(), {0, 0, -1, 1}), data_type),
414 neighbor_broadcast_dims, {last_warp_dim}),
415 neighbors_data, dot_dims, /*precision_config=*/nullptr);
416
417 // img_cxfy - img_fxfy
418 auto top_right_minus_top_left = xla::DotGeneral(
419 xla::BroadcastInDim(
420 xla::ConvertElementType(
421 xla::ConstantR1<float>(ctx->builder(), {-1, 1, 0, 0}), data_type),
422 neighbor_broadcast_dims, {last_warp_dim}),
423 neighbors_data, dot_dims, /*precision_config=*/nullptr);
424
425 // img_cxcy - img_cxfy
426 auto bottom_right_minus_top_right = xla::DotGeneral(
427 xla::BroadcastInDim(
428 xla::ConvertElementType(
429 xla::ConstantR1<float>(ctx->builder(), {0, -1, 0, 1}), data_type),
430 neighbor_broadcast_dims, {last_warp_dim}),
431 neighbors_data, dot_dims, /*precision_config=*/nullptr);
432
433 // img_fxcy - img_fxfy
434 auto bottom_left_minus_top_left = xla::DotGeneral(
435 xla::BroadcastInDim(
436 xla::ConvertElementType(
437 xla::ConstantR1<float>(ctx->builder(), {-1, 0, 1, 0}), data_type),
438 neighbor_broadcast_dims, {last_warp_dim}),
439 neighbors_data, dot_dims, /*precision_config=*/nullptr);
440
441 // Slice out x and y.
442 auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1,
443 /*stride=*/1, /*dimno=*/last_warp_dim);
444 auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2,
445 /*stride=*/1, /*dimno=*/last_warp_dim);
446
447 // Build 1 - y and 1 - x.
448 auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y;
449 auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x;
450
451 auto x_before_reduce =
452 grad_output * weight_y * bottom_right_minus_bottom_left +
453 one_minus_y * top_right_minus_top_left;
454
455 std::vector<int64> reshaped_sizes = warp_dims_without_last_dims;
456 reshaped_sizes.push_back(1);
457
458 std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size());
459 std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0);
460
461 // Reduce-add along the channel dimension.
462 auto x_result =
463 xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type),
464 xla::CreateScalarAddComputation(data_type, ctx->builder()),
465 {last_warp_dim});
466 // Reshape before concatenating with y values.
467 XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes);
468
469 auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right +
470 one_minus_x * bottom_left_minus_top_left;
471 // Reduce-add along the channel dimension.
472 auto y_result =
473 xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type),
474
475 xla::CreateScalarAddComputation(data_type, ctx->builder()),
476 {last_warp_dim});
477 XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes);
478
479 return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y},
480 last_warp_dim);
481 }
482
483 class ResamplerOp : public XlaOpKernel {
484 public:
ResamplerOp(OpKernelConstruction * ctx)485 explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
486
Compile(XlaOpKernelContext * ctx)487 void Compile(XlaOpKernelContext* ctx) override {
488 TensorShape data_shape = ctx->InputShape("data");
489 OP_REQUIRES(ctx, data_shape.dims() == 4,
490 errors::InvalidArgument("data must be 4-dimensional",
491 data_shape.DebugString()));
492 const int64 data_channels = data_shape.dim_size(3);
493 xla::PrimitiveType data_type = ctx->input_xla_type(0);
494
495 TensorShape warp_shape = ctx->InputShape("warp");
496 OP_REQUIRES(ctx, warp_shape.dims() >= 2,
497 errors::InvalidArgument("warp must be at least 2-dimensional",
498 warp_shape.DebugString()));
499 for (int size : warp_shape.dim_sizes()) {
500 OP_REQUIRES(ctx, size > 0,
501 errors::InvalidArgument("warp sizes must be positive, got [",
502 size, "]"));
503 }
504 const int64 last_warp_dim = warp_shape.dims() - 1;
505 // Last dimension of warp shape must be of size 2.
506 OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
507 errors::InvalidArgument(
508 "the last dimension of warp must be exactly size 2."));
509 xla::PrimitiveType warp_type = ctx->input_xla_type(1);
510
511 XlaOp data = ctx->Input("data");
512 XlaOp warp = ctx->Input("warp");
513
514 // Find the coordinates of the top left corner for the 2x2 region to be
515 // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the
516 // last dimension of size 2 in turn is [x, y].
517 XlaOp top_left = xla::ConvertElementType(warp, xla::S32);
518
519 auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
520
521 // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
522 auto neighbors_data = Gather2by2Neighbors(
523 ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
524
525 // Dimensions are [batch, dim_0, ... dim_n, 2].
526 XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type);
527
528 // Obtain the bilinear blending weights, the dimension is [batch, dim_0,
529 // ...dim_n, 4].
530 auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type);
531
532 // Since we will be creating the dot product of:
533 // lhs: [batch, dim_0, ...dim_n, 4]
534 // and
535 // rhs: [batch, dim_0, ...dim_n, 4, data_channels]
536 // we choose the last dimension of lhs and the second last dimension of rhs,
537 // with size 4, as the contracting dimension.
538 xla::DotDimensionNumbers dot_dims;
539 for (int i = 0; i < warp_shape.dims() - 1; ++i) {
540 dot_dims.add_lhs_batch_dimensions(i);
541 dot_dims.add_rhs_batch_dimensions(i);
542 }
543 dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
544 dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
545
546 // The dimension is [batch, dim_0, ...dim_n, data_channels].
547 auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
548 /*precision_config=*/nullptr);
549
550 // Handle out of boundary cases by constructing a predicate mask array based
551 // on the in-bound condition, and output 0 for the blended pixel value if
552 // out-bound. The dimension is the same as top_left: [batch, dim_0,
553 // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate.
554
555 auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp));
556
557 auto is_lt_image_size = xla::Lt(
558 warp,
559 xla::ConvertElementType(
560 xla::ConstantR1<float>(
561 ctx->builder(),
562 {/*width=*/static_cast<float>(data_shape.dim_size(2) - 1),
563 /*height=*/static_cast<float>(data_shape.dim_size(1) - 1)}),
564 warp_type),
565 /*broadcast_dimensions=*/{warp_shape.dims() - 1});
566
567 auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size);
568 // Reduce along last dimension. The resulting dimension is:
569 // [batch, dim_0, ...dim_n].
570 auto is_in_bound = xla::Reduce(
571 is_in_bound_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
572 xla::CreateScalarAndComputation(xla::PrimitiveType::PRED,
573 ctx->builder()),
574 {last_warp_dim});
575
576 // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which
577 // is the dimension of the result:
578 // [batch, dim_0, ...dim_n, data_channels].
579 auto warp_dims = warp_shape.dim_sizes();
580 std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
581 result_dims.push_back(data_channels);
582
583 std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
584 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
585 auto broadcasted_is_in_bound =
586 xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
587
588 // Set out of bound samples to zero.
589 auto zeros =
590 xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims);
591 auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros);
592
593 ctx->SetOutput(0, result);
594 }
595 };
596
597 REGISTER_XLA_OP(Name("Resampler"), ResamplerOp);
598
599 class ResamplerGradOp : public XlaOpKernel {
600 public:
ResamplerGradOp(OpKernelConstruction * ctx)601 explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
602 DataType output_dtype;
603 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
604 }
605
606 // TODO(b/112295522): note that sampling from image boundary is not currently
607 // being handled properly.
Compile(XlaOpKernelContext * ctx)608 void Compile(XlaOpKernelContext* ctx) override {
609 TensorShape data_shape_tf = ctx->InputShape("data");
610 OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
611 errors::InvalidArgument("data must be 4-dimensional",
612 data_shape_tf.DebugString()));
613 const int64 data_channels = data_shape_tf.dim_size(3);
614 xla::PrimitiveType data_type = ctx->input_xla_type(0);
615
616 TensorShape warp_shape = ctx->InputShape("warp");
617 OP_REQUIRES(ctx, warp_shape.dims() >= 2,
618 errors::InvalidArgument("warp must be at least 2-dimensional",
619 warp_shape.DebugString()));
620 for (int size : warp_shape.dim_sizes()) {
621 OP_REQUIRES(ctx, size > 0,
622 errors::InvalidArgument("warp sizes must be positive, got [",
623 size, "]"));
624 }
625 // Last dimension of warp shape must be of size 2.
626 const int64 last_warp_dim = warp_shape.dims() - 1;
627 OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
628 errors::InvalidArgument(
629 "the last dimension of warp must be exactly size 2."));
630 xla::PrimitiveType warp_type = ctx->input_xla_type(1);
631
632 TensorShape output_grad_shape = ctx->InputShape("grad_output");
633 OP_REQUIRES(
634 ctx, output_grad_shape.dims() >= 2,
635 errors::InvalidArgument("output_grad must be at least 2-dimensional",
636 output_grad_shape.DebugString()));
637
638 // Dimensions are [batch, x, y, channel].
639 XlaOp data = ctx->Input("data");
640 xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf);
641
642 // Dimensions are [batch, dim_0, ...dim_n, 2].
643 XlaOp warp = ctx->Input("warp");
644 // Dimensions are [batch, dim_0, ...dim_n, channel].
645 XlaOp grad_output = ctx->Input("grad_output");
646
647 // Find the top left corner coordinate for the region to be sampled from.
648 // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
649 // of size 2 in turn is [x, y].
650 XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32);
651
652 // Dimensions are [batch, dim_0, ... dim_n, 2].
653 XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
654
655 // Indices for gathering neighboring pixels.
656 auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
657
658 auto grad_data = CalculateGradData(
659 ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape,
660 last_warp_dim, data_channels, data_shape);
661
662 auto grad_warp =
663 CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
664 warp_shape, data_channels, data_type, data_shape);
665 auto warp_dims = warp_shape.dim_sizes();
666 std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
667 result_dims.push_back(2);
668 std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
669 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
670 auto grad_warp_bounded =
671 BoundSamples(ctx, warp, warp_type, warp_shape, result_dims,
672 broadcasted_dims, last_warp_dim, data_shape, grad_warp);
673
674 ctx->SetOutput(0, grad_data);
675 ctx->SetOutput(1, grad_warp_bounded);
676 }
677 };
678
679 REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp);
680
681 } // namespace
682 } // namespace tensorflow
683