/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for 2D convolution.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Returns the expanded size of a filter used for depthwise convolution.
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N].
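// For example (an illustrative shape, not taken from any caller): a depthwise
// filter of shape [3, 3, 8, 2] is regrouped to [3, 3, 1, 16].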
xla::Shape GroupedFilterShapeForDepthwiseConvolution(
    const xla::Shape& filter_shape) {
  int64 input_feature_dim = filter_shape.dimensions_size() - 2;
  int64 output_feature_dim = filter_shape.dimensions_size() - 1;
  int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
  int64 input_feature = filter_shape.dimensions(input_feature_dim);

  // Create a [H, W, ..., 1, M*N] reshape of the filter.
  xla::Shape grouped_filter_shape = filter_shape;
  grouped_filter_shape.set_dimensions(input_feature_dim, 1);
  grouped_filter_shape.set_dimensions(output_feature_dim,
                                      depthwise_multiplier * input_feature);
  return grouped_filter_shape;
}

// Returns the transposed filter for use in BackpropInput of group convolution.
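// For example (illustrative sizes): with num_spatial_dims = 2 and
// num_groups = 4, a [3, 3, 16, 32] filter is reshaped to [3, 3, 16, 4, 8],
// transposed to [3, 3, 4, 16, 8], and collapsed to [3, 3, 64, 8].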
xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
    xla::XlaOp filter, const xla::Shape& filter_shape, int64 num_groups,
    int num_spatial_dims) {
  // 1. Reshape from [H, W, ..., filter_in_depth, out_depth] to [H, W, ...,
  // filter_in_depth, G, out_depth / G]
  int num_dims = filter_shape.dimensions_size();
  CHECK_GE(num_dims, 2);  // Crash OK
  xla::Shape new_shape = filter_shape;
  new_shape.set_dimensions(num_dims - 1, num_groups);
  new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups);
  xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions());

  // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]
  std::vector<int64> transpose_dims(num_dims + 1);
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  std::swap(transpose_dims[num_spatial_dims],
            transpose_dims[num_spatial_dims + 1]);
  result = xla::Transpose(result, transpose_dims);

  // 3. Reshape to [H, W, ..., in_depth, out_depth / G]
  result = xla::Collapse(result, {num_spatial_dims, num_spatial_dims + 1});
  return result;
}

// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
// build a depthwise convolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
                                                xla::XlaOp filter) {
  return xla::Reshape(
      filter,
      GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions());
}

// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
// convolutions (as currently implemented).
Status CheckConvAttrs(const ConvOpAttrs& attrs) {
  const int num_dims = attrs.num_spatial_dims + 2;
  const int attrs_strides_size = attrs.strides.size();
  if (attrs_strides_size != num_dims) {
    return errors::InvalidArgument("Sliding window strides field must specify ",
                                   num_dims, " dimensions");
  }
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not yet support strides in the batch and "
        "depth dimensions.");
  }
  const int attrs_dilations_size = attrs.dilations.size();
  if (attrs_dilations_size != num_dims) {
    return errors::InvalidArgument("Dilations field must specify ", num_dims,
                                   " dimensions");
  }
  if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not support dilations in the batch and "
        "depth dimensions.");
  }
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (attrs.dilations[input_dim] < 1) {
      return errors::Unimplemented("Dilation values must be positive; ", i,
                                   "th spatial dimension had dilation ",
                                   attrs.dilations[input_dim]);
    }
  }
  return Status::OK();
}

// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
// to TensorShapes.
Status ConvBackpropComputeDimensionsV2XlaShapes(
    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
    absl::Span<const int32> dilations, const std::vector<int32>& strides,
    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
    absl::Span<const int64> explicit_paddings) {
  TensorShape input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
  return ConvBackpropComputeDimensionsV2(
      label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
      data_format, dims);
}

}  // anonymous namespace

std::vector<DataType> GetXlaConvTypes() {
  return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
}

xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
                                               bool depthwise,
                                               OpKernelConstruction* ctx) {
  ConvOpAttrs attrs;
  attrs.num_spatial_dims = num_spatial_dims;
  attrs.depthwise = depthwise;
  TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
  TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
  TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
  if (attrs.padding == EXPLICIT) {
    TF_RETURN_IF_ERROR(
        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
  }

  string data_format;
  TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
  if (!FormatFromString(data_format, &attrs.data_format)) {
    return errors::InvalidArgument("Invalid data format: ", data_format);
  }

  TF_RETURN_IF_ERROR(CheckValidPadding(attrs.padding, attrs.explicit_paddings,
                                       /*num_dims=*/num_spatial_dims + 2,
                                       attrs.data_format));

  return attrs;
}

xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
    StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = conv_input.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));

  // For 2D convolution, there should be 4 dimensions.
  int num_dims = attrs.num_spatial_dims + 2;
  if (input_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument("input must be ", num_dims,
                                   "-dimensional: ", input_shape.DebugString());
  }
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(
        "filter must be ", num_dims,
        "-dimensional: ", filter_shape.DebugString());
  }

  // The last two dimensions of the filter are the input and output shapes.
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        out_depth = filter_shape.dimensions(attrs.num_spatial_dims + 1),
        in_depth = input_shape.dimensions(feature_dim);
  // The 'C' dimension for input is in_depth.
  // It must be a multiple of the filter's in_depth.
  if (in_depth % filter_in_depth != 0) {
    return errors::InvalidArgument(
        "Depth of input must be a multiple of depth of filter: ", in_depth,
        " vs ", filter_in_depth);
  }
  int64 feature_group_count = in_depth / filter_in_depth;
  if (out_depth % feature_group_count != 0) {
    return errors::InvalidArgument(
        "Depth of output must be a multiple of the number of groups: ",
        out_depth, " vs ", feature_group_count);
  }
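  // For example (hypothetical sizes, for illustration only): in_depth = 16
  // with filter_in_depth = 4 gives feature_group_count = 4, i.e. a 4-group
  // convolution, and out_depth must then be a multiple of 4.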

  if (attrs.depthwise) {
    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
  }

  xla::ConvolutionDimensionNumbers dims;
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);

  dims.set_input_batch_dimension(batch_dim);
  dims.set_output_batch_dimension(batch_dim);
  dims.set_input_feature_dimension(feature_dim);
  dims.set_output_feature_dimension(feature_dim);
  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
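  // For example, with 2D NHWC data this gives input/output batch dimension 0,
  // input/output feature dimension 3, and kernel input/output feature
  // dimensions 2 and 3; the spatial dimensions {1, 2} (input/output) and
  // {0, 1} (kernel) are added in the loop below.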
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (input_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dims.add_input_spatial_dimensions(dim);
    dims.add_kernel_spatial_dimensions(i);
    dims.add_output_spatial_dimensions(dim);
    window_strides[i] = attrs.strides.at(dim);
    rhs_dilation[i] = attrs.dilations.at(dim);

    if (attrs.padding == EXPLICIT) {
      padding[i] = {attrs.explicit_paddings.at(dim * 2),
                    attrs.explicit_paddings.at(dim * 2 + 1)};
    }

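    // Note: GetWindowedOutputSizeVerboseV2 below is expected to read the
    // explicit padding seeded above (the before/after values act as in/out
    // parameters under EXPLICIT padding) and to compute the padding itself
    // for SAME/VALID.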
    int64 unused_output_size;
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
        input_shape.dimensions(dim), filter_shape.dimensions(i),
        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
        &padding[i].first, &padding[i].second));
  }

  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    return xla::DynamicConvForward(
        conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
        dims,
        /*feature_group_count=*/attrs.depthwise ? in_depth
                                                : feature_group_count,
        /*batch_group_count=*/1, precision_config, padding_type);
  }

  return xla::ConvGeneralDilated(
      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
      dims,
      /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
      /*batch_group_count=*/1, precision_config);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
    const xla::PrecisionConfig* precision_config, xla::XlaOp* input_sizes) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  int num_dims = attrs.num_spatial_dims + 2;
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

  int64 in_depth = input_shape.dimensions(feature_dim),
        filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        feature_group_count =
            attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;

  xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, grouped_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_shape_utils.h for details.
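  // Concretely, the convolution emitted below uses the output gradients as
  // the LHS, lhs_dilation equal to the forward stride (inserting stride - 1
  // zeros between gradient elements), and the filter reversed in its spatial
  // dimensions with its input/output feature dimensions swapped.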

  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (out_backprop_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  if (feature_group_count != 1 && !attrs.depthwise) {
    filter = TransposeFilterForGroupConvolutionBackpropInput(
        filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
  }
  // Mirror the filter in the spatial dimensions.
  filter = xla::Rev(filter, kernel_spatial_dims);
  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    TF_RET_CHECK(input_sizes != nullptr);
    return xla::DynamicConvInputGrad(
        *input_sizes, out_backprop, filter, /*window_strides=*/ones, padding,
        lhs_dilation, rhs_dilation, dnums,
        /*feature_group_count=*/
        feature_group_count,
        /*batch_group_count=*/1, precision_config, padding_type);
  }
  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(out_backprop, filter, /*window_strides=*/ones,
                                 padding, lhs_dilation, rhs_dilation, dnums,
                                 /*feature_group_count=*/
                                 feature_group_count,
                                 /*batch_group_count=*/1, precision_config);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));
  xla::XlaOp filter_backprop;

  xla::Shape input_shape = activations_shape;
  xla::Shape output_shape = out_backprop_shape;

  TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));

  const xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc.
  ConvBackpropDimensions dims;
  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_shape_utils.h for details.
  xla::ConvolutionDimensionNumbers dnums;

  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      grouped_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // Obtain some useful dimensions:
  // The last two dimensions of the filter are the input and output shapes.
  int num_dims = attrs.num_spatial_dims + 2;
  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  int64 in_depth = input_shape.dimensions(c_dim),
        filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        batch_group_count =
            attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  dnums.set_output_batch_dimension(attrs.num_spatial_dims);
  dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);

  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    dnums.add_output_spatial_dimensions(i);
  }
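  // For example, with 2D NHWC data this gives input batch dimension 3 (the
  // activations' C), input feature dimension 0 (the activations' N), kernel
  // input/output feature dimensions 0 and 3, output spatial dimensions
  // {0, 1}, and output batch/feature dimensions 2 and 3, matching the
  // TensorFlow [H, W, inC, outC] filter layout of the result.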
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (activations_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.

    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    // INPUT =  [ A  B  C ]
    //
    // FILTER = [ x y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad top/left side the same as bottom/right
    //   side.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some of the trailing elements of the input. We do
    // this by applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings[2 * dim]
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
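
    // For example (hypothetical sizes, for illustration only): with
    // input_size = 5, filter_size = 3, stride = 2, dilation = 1 and SAME
    // padding, the forward output size is 3, so expanded_output_size = 5,
    // padded_in_size = 7, pad_total = 2, and pad_before = pad_after = 1.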
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.
  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    filter_backprop = xla::DynamicConvKernelGrad(
        activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
        rhs_dilation, dnums,
        /*feature_group_count=*/1,
        /*batch_group_count=*/batch_group_count, precision_config,
        padding_type);
  } else {
    filter_backprop = xla::ConvGeneralDilated(
        activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
        rhs_dilation, dnums,
        /*feature_group_count=*/1,
        /*batch_group_count=*/batch_group_count, precision_config);
  }

  if (attrs.depthwise) {
    filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
  }

  return filter_backprop;
}

}  // namespace tensorflow