/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for 2D convolution.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Returns the expanded size of a filter used for depthwise convolution.
// If `filter_shape` is [H, W, ..., M, N], returns [H, W, ..., 1, M*N].
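// For example, a [3, 3, 8, 2] depthwise filter (8 input channels, depth
// multiplier 2) expands to [3, 3, 1, 16].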
xla::Shape GroupedFilterShapeForDepthwiseConvolution(
    const xla::Shape& filter_shape) {
  int64 input_feature_dim = filter_shape.dimensions_size() - 2;
  int64 output_feature_dim = filter_shape.dimensions_size() - 1;
  int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
  int64 input_feature = filter_shape.dimensions(input_feature_dim);

  // Create a [H, W, ..., 1, M*N] reshape of the filter.
  xla::Shape grouped_filter_shape = filter_shape;
  grouped_filter_shape.set_dimensions(input_feature_dim, 1);
  grouped_filter_shape.set_dimensions(output_feature_dim,
                                      depthwise_multiplier * input_feature);
  return grouped_filter_shape;
}

// Returns the transposed filter for use in BackpropInput of group convolution.
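// For example, with G = 2 groups, a [H, W, 4, 6] filter is reshaped to
// [H, W, 4, 2, 3], transposed to [H, W, 2, 4, 3], and collapsed to
// [H, W, 8, 3].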
xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
    xla::XlaOp filter, const xla::Shape& filter_shape, int64 num_groups,
    int num_spatial_dims) {
  // 1. Reshape from [H, W, ..., filter_in_depth, out_depth] to [H, W, ...,
  // filter_in_depth, G, out_depth / G]
  int num_dims = filter_shape.dimensions_size();
  CHECK_GE(num_dims, 2);  // Crash OK
  xla::Shape new_shape = filter_shape;
  new_shape.set_dimensions(num_dims - 1, num_groups);
  new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups);
  xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions());

  // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]
  std::vector<int64> transpose_dims(num_dims + 1);
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  std::swap(transpose_dims[num_spatial_dims],
            transpose_dims[num_spatial_dims + 1]);
  result = xla::Transpose(result, transpose_dims);

  // 3. Reshape to [H, W, ..., in_depth, out_depth / G]
  result = xla::Collapse(result, {num_spatial_dims, num_spatial_dims + 1});
  return result;
}

// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
// build a depthwise convolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
                                                xla::XlaOp filter) {
  return xla::Reshape(
      filter,
      GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions());
}

// Performs basic checks on ConvOpAttrs that must hold for all kinds of XLA
// convolutions (as currently implemented).
Status CheckConvAttrs(const ConvOpAttrs& attrs) {
  const int num_dims = attrs.num_spatial_dims + 2;
  const int attrs_strides_size = attrs.strides.size();
  if (attrs_strides_size != num_dims) {
    return errors::InvalidArgument("Sliding window strides field must specify ",
                                   num_dims, " dimensions");
  }
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not yet support strides in the batch and "
        "depth dimensions.");
  }
  const int attrs_dilations_size = attrs.dilations.size();
  if (attrs_dilations_size != num_dims) {
    return errors::InvalidArgument("Dilations field must specify ", num_dims,
                                   " dimensions");
  }
  if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not support dilations in the batch and "
        "depth dimensions.");
  }
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (attrs.dilations[input_dim] < 1) {
      return errors::Unimplemented("Dilation values must be positive; ", i,
                                   "th spatial dimension had dilation ",
                                   attrs.dilations[input_dim]);
    }
  }
  return Status::OK();
}

// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
// to TensorShapes.
Status ConvBackpropComputeDimensionsV2XlaShapes(
    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
    absl::Span<const int32> dilations, const std::vector<int32>& strides,
    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
    absl::Span<const int64> explicit_paddings) {
  TensorShape input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
  return ConvBackpropComputeDimensionsV2(
      label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
      data_format, dims);
}

}  // anonymous namespace

std::vector<DataType> GetXlaConvTypes() {
  return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
}

xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
                                               bool depthwise,
                                               OpKernelConstruction* ctx) {
  ConvOpAttrs attrs;
  attrs.num_spatial_dims = num_spatial_dims;
  attrs.depthwise = depthwise;
  TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
  TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
  TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
  if (attrs.padding == EXPLICIT) {
    TF_RETURN_IF_ERROR(
        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
  }

  string data_format;
  TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
  if (!FormatFromString(data_format, &attrs.data_format)) {
    return errors::InvalidArgument("Invalid data format: ", data_format);
  }

  TF_RETURN_IF_ERROR(CheckValidPadding(attrs.padding, attrs.explicit_paddings,
                                       /*num_dims=*/num_spatial_dims + 2,
                                       attrs.data_format));

  return attrs;
}

xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
    StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = conv_input.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));

  // For 2D convolution, there should be 4 dimensions.
  int num_dims = attrs.num_spatial_dims + 2;
  if (input_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument("input must be ", num_dims,
                                   "-dimensional: ", input_shape.DebugString());
  }
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(
        "filter must be ", num_dims,
        "-dimensional: ", filter_shape.DebugString());
  }

  // The last two dimensions of the filter are the input and output depths.
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        out_depth = filter_shape.dimensions(attrs.num_spatial_dims + 1),
        in_depth = input_shape.dimensions(feature_dim);
  // The 'C' dimension for input is in_depth.
  // It must be a multiple of the filter's in_depth.
  if (in_depth % filter_in_depth != 0) {
    return errors::InvalidArgument(
        "Depth of input must be a multiple of depth of filter: ", in_depth,
        " vs ", filter_in_depth);
  }
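  // feature_group_count > 1 denotes a grouped convolution; e.g. in_depth = 8
  // with filter_in_depth = 4 gives two feature groups.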
  int64 feature_group_count = in_depth / filter_in_depth;
  if (out_depth % feature_group_count != 0) {
    return errors::InvalidArgument(
        "Depth of output must be a multiple of the number of groups: ",
        out_depth, " vs ", feature_group_count);
  }

  if (attrs.depthwise) {
    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
  }

  xla::ConvolutionDimensionNumbers dims;
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);

  dims.set_input_batch_dimension(batch_dim);
  dims.set_output_batch_dimension(batch_dim);
  dims.set_input_feature_dimension(feature_dim);
  dims.set_output_feature_dimension(feature_dim);
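  // TF filters are laid out [spatial..., in_depth, out_depth], so the kernel
  // feature dimensions are the last two.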
  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (input_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dims.add_input_spatial_dimensions(dim);
    dims.add_kernel_spatial_dimensions(i);
    dims.add_output_spatial_dimensions(dim);
    window_strides[i] = attrs.strides.at(dim);
    rhs_dilation[i] = attrs.dilations.at(dim);

    if (attrs.padding == EXPLICIT) {
      padding[i] = {attrs.explicit_paddings.at(dim * 2),
                    attrs.explicit_paddings.at(dim * 2 + 1)};
    }

    int64 unused_output_size;
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
        input_shape.dimensions(dim), filter_shape.dimensions(i),
        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
        &padding[i].first, &padding[i].second));
  }

  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    return xla::DynamicConvForward(
        conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
        dims,
        /*feature_group_count=*/attrs.depthwise ? in_depth
                                                : feature_group_count,
        /*batch_group_count=*/1, precision_config, padding_type);
  }

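  // With static shapes, emit an ordinary ConvGeneralDilated.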
  return xla::ConvGeneralDilated(
      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
      dims,
      /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
      /*batch_group_count=*/1, precision_config);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
    const xla::PrecisionConfig* precision_config, xla::XlaOp* input_sizes) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  int num_dims = attrs.num_spatial_dims + 2;
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

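  // For a depthwise convolution every input channel forms its own group;
  // otherwise the group count matches the forward convolution.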
  int64 in_depth = input_shape.dimensions(feature_dim),
        filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        feature_group_count =
            attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;

  xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, grouped_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_shape_utils.h for details.

  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (out_backprop_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
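    // Dilating the output gradients by the forward stride inserts zeros
    // between the gradient elements, implementing the transposed convolution.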
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  if (feature_group_count != 1 && !attrs.depthwise) {
    filter = TransposeFilterForGroupConvolutionBackpropInput(
        filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
  }
  // Mirror the filter in the spatial dimensions.
  filter = xla::Rev(filter, kernel_spatial_dims);
  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    TF_RET_CHECK(input_sizes != nullptr);
    return xla::DynamicConvInputGrad(
        *input_sizes, out_backprop, filter, /*window_strides=*/ones, padding,
        lhs_dilation, rhs_dilation, dnums,
        /*feature_group_count=*/
        feature_group_count,
        /*batch_group_count=*/1, precision_config, padding_type);
  }
  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(out_backprop, filter, /*window_strides=*/ones,
                                 padding, lhs_dilation, rhs_dilation, dnums,
                                 /*feature_group_count=*/
                                 feature_group_count,
                                 /*batch_group_count=*/1, precision_config);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));
  xla::XlaOp filter_backprop;

  xla::Shape input_shape = activations_shape;
  xla::Shape output_shape = out_backprop_shape;

  TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));

  const xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc.
  ConvBackpropDimensions dims;
  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_shape_utils.h for details.
  xla::ConvolutionDimensionNumbers dnums;

  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      grouped_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // Obtain some useful dimensions:
  // The last two dimensions of the filter are the input and output depths.
  int num_dims = attrs.num_spatial_dims + 2;
  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  int64 in_depth = input_shape.dimensions(c_dim),
        filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
        batch_group_count =
            attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;
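  // Grouped and depthwise filter gradients are expressed through
  // batch_group_count rather than feature_group_count.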

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  dnums.set_output_batch_dimension(attrs.num_spatial_dims);
  dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
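  // The result of this convolution is the filter gradient, laid out
  // [spatial..., in_depth, out_depth] like a TF filter.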

  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    dnums.add_output_spatial_dimensions(i);
  }
  xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
  for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (activations_shape.is_dynamic_dimension(dim)) {
      TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
          << "Dynamic convolution only supports valid and same padding";
      if (attrs.padding == VALID) {
        padding_type = xla::PaddingType::PADDING_VALID;
      }
      if (attrs.padding == SAME) {
        padding_type = xla::PaddingType::PADDING_SAME;
      }
    }
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.

    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    // INPUT =  [ A  B  C ]
    //
    // FILTER = [ x y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad top/left side the same as bottom/right
    //   side.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some trailing elements of the input. We do this by
    // applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings[2 * dim]
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.
  if (padding_type != xla::PaddingType::PADDING_INVALID) {
    filter_backprop = xla::DynamicConvKernelGrad(
        activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
        rhs_dilation, dnums,
        /*feature_group_count=*/1,
        /*batch_group_count=*/batch_group_count, precision_config,
        padding_type);
  } else {
    filter_backprop = xla::ConvGeneralDilated(
        activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
        rhs_dilation, dnums,
        /*feature_group_count=*/1,
        /*batch_group_count=*/batch_group_count, precision_config);
  }

  if (attrs.depthwise) {
    filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
  }

  return filter_backprop;
}

}  // namespace tensorflow