1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific Ops for 2D convolution.
17 
18 #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
19 #include "absl/types/span.h"
20 #include "tensorflow/compiler/tf2xla/shape_util.h"
21 #include "tensorflow/compiler/tf2xla/type_util.h"
22 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/core/framework/bounds_check.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/numeric_op.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/ops_util.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_slice.h"
37 #include "tensorflow/core/kernels/conv_grad_ops.h"
38 #include "tensorflow/core/util/padding.h"
39 #include "tensorflow/core/util/tensor_format.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 // Returns the expanded size of a filter used for depthwise convolution.
45 // If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape & shape)46 xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
47   int num_dims = shape.dimensions_size();
48   CHECK_GE(num_dims, 2);  // Crash OK
49   xla::Shape expanded_shape = shape;
50   expanded_shape.set_dimensions(
51       num_dims - 1,
52       shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
53   return expanded_shape;
54 }
55 
56 // Create a mask for depthwise convolution that will make a normal convolution
57 // produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
58 // depthwise filter this returns a [2, 2, 3, 6] tensor
59 //   1 1 0 0 0 0   1 1 0 0 0 0
60 //   0 0 1 1 0 0   0 0 1 1 0 0
61 //   0 0 0 0 1 1   0 0 0 0 1 1
62 //
63 //   1 1 0 0 0 0   1 1 0 0 0 0
64 //   0 0 1 1 0 0   0 0 1 1 0 0
65 //   0 0 0 0 1 1   0 0 0 0 1 1
66 //
67 // The first step is to create a iota A with iota_dimension = 2
68 //   0 0 0 0 0 0   0 0 0 0 0 0
69 //   1 1 1 1 1 1   1 1 1 1 1 1
70 //   2 2 2 2 2 2   2 2 2 2 2 2
71 //
72 //   0 0 0 0 0 0   0 0 0 0 0 0
73 //   1 1 1 1 1 1   1 1 1 1 1 1
74 //   2 2 2 2 2 2   2 2 2 2 2 2
75 //
76 // and another iota B with iota_dimension = 3
77 //   0 1 2 3 4 5  0 1 2 3 4 5
78 //   0 1 2 3 4 5  0 1 2 3 4 5
79 //   0 1 2 3 4 5  0 1 2 3 4 5
80 //
81 //   0 1 2 3 4 5  0 1 2 3 4 5
82 //   0 1 2 3 4 5  0 1 2 3 4 5
83 //   0 1 2 3 4 5  0 1 2 3 4 5
84 //
85 // and divide B by 2 to get
86 //   0 0 1 1 2 2  0 0 1 1 2 2
87 //   0 0 1 1 2 2  0 0 1 1 2 2
88 //   0 0 1 1 2 2  0 0 1 1 2 2
89 //
90 //   0 0 1 1 2 2  0 0 1 1 2 2
91 //   0 0 1 1 2 2  0 0 1 1 2 2
92 //   0 0 1 1 2 2  0 0 1 1 2 2
93 //
94 // Finally compare A and B and return the result at the beginning of the
95 // comment.
CreateExpandedFilterMask(const xla::Shape & filter_shape,xla::XlaBuilder * builder)96 xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
97                                     xla::XlaBuilder* builder) {
98   xla::Shape expanded_filter_shape =
99       ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
100   int64 depthwise_multiplier =
101       filter_shape.dimensions(filter_shape.dimensions_size() - 1);
102 
103   // Create two iotas with the shape of the expanded filter, one of them with
104   // the iota dimension chosen as the feature dimension, and the other a iota
105   // with the iota dimension chosen as the expanded output feature dimension.
106   std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
107                                      expanded_filter_shape.dimensions().end());
108   xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
109   xla::XlaOp input_feature_iota = xla::Iota(
110       builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
111   xla::XlaOp expanded_feature_iota = xla::Iota(
112       builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
113 
114   // Divide 'expanded_feature_iota' by the depthwise_multiplier to create
115   // [0 0 1 1 2 2] ... in the example in the function comment.
116   expanded_feature_iota =
117       xla::Div(expanded_feature_iota,
118                XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
119                                           depthwise_multiplier));
120 
121   // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
122   // diagonal predicate.
123   return xla::Eq(expanded_feature_iota, input_feature_iota);
124 }
125 
126 // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
127 // build a depthwise convolution.
ReshapeFilterForDepthwiseConvolution(const xla::Shape & filter_shape,const xla::XlaOp & filter)128 xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
129                                                 const xla::XlaOp& filter) {
130   int64 input_feature_dim = filter_shape.dimensions_size() - 2;
131   int64 output_feature_dim = filter_shape.dimensions_size() - 1;
132   int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
133   int64 input_feature = filter_shape.dimensions(input_feature_dim);
134 
135   // Create a [H, W, ..., 1, N*M] reshape of the filter.
136   xla::Shape implicit_broadcast_filter_shape = filter_shape;
137   implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
138   implicit_broadcast_filter_shape.set_dimensions(
139       output_feature_dim, depthwise_multiplier * input_feature);
140   return xla::Reshape(
141       filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
142 }
143 
144 // Reduces the results of the convolution with an expanded filter to the
145 // non-expanded filter.
ContractFilterForDepthwiseBackprop(const xla::Shape & filter_shape,const xla::XlaOp & filter_backprop,xla::XlaBuilder * builder)146 xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
147                                               const xla::XlaOp& filter_backprop,
148                                               xla::XlaBuilder* builder) {
149   auto masked_expanded_filter =
150       xla::Select(CreateExpandedFilterMask(filter_shape, builder),
151                   filter_backprop, xla::ZerosLike(filter_backprop));
152 
153   auto elem_type = filter_shape.element_type();
154   return xla::Reshape(
155       // This reduce does not need inputs to be converted with
156       // XlaHelpers::SumAccumulationType() since the select above guarantees
157       // that only one element is non zero, so there cannot be accumulated
158       // precision error.
159       xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
160                   CreateScalarAddComputation(elem_type, builder),
161                   {filter_shape.dimensions_size() - 2}),
162       xla::AsInt64Slice(filter_shape.dimensions()));
163 }
164 
165 // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
166 // convolutions (as currently implemented).
CheckConvAttrs(const ConvOpAttrs & attrs)167 Status CheckConvAttrs(const ConvOpAttrs& attrs) {
168   const int num_dims = attrs.num_spatial_dims + 2;
169   if (attrs.strides.size() != num_dims) {
170     return errors::InvalidArgument("Sliding window strides field must specify ",
171                                    num_dims, " dimensions");
172   }
173   int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
174   int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
175   if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
176     return errors::Unimplemented(
177         "Current implementation does not yet support strides in the batch and "
178         "depth dimensions.");
179   }
180   if (attrs.dilations.size() != num_dims) {
181     return errors::InvalidArgument("Dilations field must specify ", num_dims,
182                                    " dimensions");
183   }
184   if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
185     return errors::Unimplemented(
186         "Current implementation does not support dilations in the batch and "
187         "depth dimensions.");
188   }
189   for (int i = 0; i < attrs.num_spatial_dims; ++i) {
190     int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
191     if (attrs.dilations[input_dim] < 1) {
192       return errors::Unimplemented("Dilation values must be positive; ", i,
193                                    "th spatial dimension had dilation ",
194                                    attrs.dilations[input_dim]);
195     }
196   }
197   return Status::OK();
198 }
199 
200 // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
201 // to TensorShapes.
ConvBackpropComputeDimensionsV2XlaShapes(StringPiece label,int num_spatial_dims,const xla::Shape & input_shape,const xla::Shape & filter_shape,const xla::Shape & out_backprop_shape,absl::Span<const int32> dilations,const std::vector<int32> & strides,Padding padding,TensorFormat data_format,ConvBackpropDimensions * dims,absl::Span<const int64> explicit_paddings)202 Status ConvBackpropComputeDimensionsV2XlaShapes(
203     StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
204     const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
205     absl::Span<const int32> dilations, const std::vector<int32>& strides,
206     Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
207     absl::Span<const int64> explicit_paddings) {
208   TensorShape input_tensor_shape, filter_tensor_shape,
209       out_backprop_tensor_shape;
210   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
211   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
212   TF_RETURN_IF_ERROR(
213       XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
214   return ConvBackpropComputeDimensionsV2(
215       label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
216       out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
217       data_format, dims);
218 }
219 
220 }  // anonymous namespace
221 
Create(int num_spatial_dims,bool depthwise,OpKernelConstruction * ctx)222 xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
223                                                bool depthwise,
224                                                OpKernelConstruction* ctx) {
225   ConvOpAttrs attrs;
226   attrs.num_spatial_dims = num_spatial_dims;
227   attrs.depthwise = depthwise;
228   TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
229   TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
230   TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
231   if (attrs.padding == EXPLICIT) {
232     TF_RETURN_IF_ERROR(
233         ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
234   }
235 
236   string data_format;
237   TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
238   if (!FormatFromString(data_format, &attrs.data_format)) {
239     return errors::InvalidArgument("Invalid data format: ", data_format);
240   }
241 
242   return attrs;
243 }
244 
MakeXlaForwardConvOp(StringPiece,xla::XlaOp conv_input,xla::XlaOp filter,const ConvOpAttrs & attrs)245 xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
246                                                xla::XlaOp conv_input,
247                                                xla::XlaOp filter,
248                                                const ConvOpAttrs& attrs) {
249   TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
250 
251   auto* builder = conv_input.builder();
252   TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
253   // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
254   TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
255 
256   // For 2D convolution, there should be 4 dimensions.
257   int num_dims = attrs.num_spatial_dims + 2;
258   if (input_shape.dimensions_size() != num_dims) {
259     return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
260                                    input_shape.DebugString());
261   }
262   if (filter_shape.dimensions_size() != num_dims) {
263     return errors::InvalidArgument(
264         "filter must be ", num_dims,
265         "-dimensional: ", filter_shape.DebugString());
266   }
267 
268   // The last two dimensions of the filter are the input and output shapes.
269   int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
270   int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
271 
272   int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
273   // The 'C' dimension for input is in_depth. It must be the same as
274   // the filter's in_depth.
275   if (in_depth != input_shape.dimensions(feature_dim)) {
276     return errors::InvalidArgument(
277         "input and filter must have the same depth: ", in_depth, " vs ",
278         input_shape.dimensions(feature_dim));
279   }
280 
281   if (attrs.depthwise) {
282     filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
283   }
284 
285   xla::ConvolutionDimensionNumbers dims;
286   std::vector<int64> window_strides(attrs.num_spatial_dims);
287   std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
288   std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
289   std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
290 
291   dims.set_input_batch_dimension(batch_dim);
292   dims.set_output_batch_dimension(batch_dim);
293   dims.set_input_feature_dimension(feature_dim);
294   dims.set_output_feature_dimension(feature_dim);
295   dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
296   dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
297 
298   for (int i = 0; i < attrs.num_spatial_dims; ++i) {
299     const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
300     dims.add_input_spatial_dimensions(dim);
301     dims.add_kernel_spatial_dimensions(i);
302     dims.add_output_spatial_dimensions(dim);
303     window_strides[i] = attrs.strides.at(dim);
304     rhs_dilation[i] = attrs.dilations.at(dim);
305 
306     if (attrs.padding == EXPLICIT) {
307       padding[i] = {attrs.explicit_paddings.at(dim * 2),
308                     attrs.explicit_paddings.at(dim * 2 + 1)};
309     }
310 
311     int64 unused_output_size;
312     TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
313         input_shape.dimensions(dim), filter_shape.dimensions(i),
314         rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
315         &padding[i].first, &padding[i].second));
316   }
317 
318   return xla::ConvGeneralDilated(
319       conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
320       dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
321 }
322 
// Computes the gradient of a convolution with respect to its input.
//
// The input gradient is itself a convolution: the (padded, stride-dilated)
// output gradients convolved with the spatially-mirrored filter. See the
// comment at the top of conv_grad_ops.h for the derivation.
//
// Args:
//   type_string: op label used in error messages from the shape computation.
//   input_shape: shape of the forward convolution's input; the result has
//     this shape.
//   filter: forward filter of shape [spatial dims..., in_depth, out_depth].
//   out_backprop: gradients w.r.t. the forward convolution's output.
//   attrs: validated convolution attributes.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  int num_dims = attrs.num_spatial_dims + 2;
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

  // For depthwise convolutions the dimension computation is done against the
  // logically expanded [spatial..., in_depth, in_depth * multiplier] filter.
  xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_ops.h for details.

  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    // Dilating the output gradients (lhs) by the forward stride "undoes" the
    // stride; the padding comes from the backprop dimension computation above.
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  // Mirror the filter in the spatial dimensions.
  xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);

  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(
      out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
      lhs_dilation, rhs_dilation, dnums,
      /*feature_group_count=*/
      // For depthwise: one group per input feature, i.e.
      // out_depth / depth_multiplier groups.
      attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
                            filter_shape.dimensions(attrs.num_spatial_dims + 1)
                      : 1);
}
393 
// Computes the gradient of a convolution with respect to its filter.
//
// The filter gradient is computed as a convolution in which the forward
// activations form the LHS and the output gradients form the RHS, with the
// batch and feature dimensions of the activations swapped. See the comment at
// the top of conv_grad_ops.h for the derivation.
//
// Args:
//   type_string: op label used in error messages from the shape computation.
//   activations: the forward convolution's input.
//   filter_shape: shape of the forward filter; the result has this shape.
//   gradients: gradients w.r.t. the forward convolution's output.
//   attrs: validated convolution attributes.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));
  xla::XlaOp filter_backprop;

  xla::Shape input_shape = activations_shape;
  xla::Shape output_shape = out_backprop_shape;

  TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));

  // For depthwise convolutions the dimension computation is done against the
  // logically expanded [spatial..., in_depth, in_depth * multiplier] filter.
  const xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_ops.h for details.
  xla::ConvolutionDimensionNumbers dnums;

  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // The activations (inputs) form the LHS of the convolution.
  // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
  // For the gradient computation, we flip the roles of the batch and
  // feature dimensions.
  // Each spatial entry has size in_depth * batch

  // The last two dimensions of the filter are the input and output shapes.
  int num_dims = attrs.num_spatial_dims + 2;
  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  // A depthwise filter whose depth multiplier (last dimension) is 1 can be
  // computed directly via XLA's batch_group_count, avoiding the
  // expand/contract round-trip performed at the end of this function.
  bool use_batch_group_count =
      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  // The dimension swap below is needed because filter shape is KH,KW,F,DM.
  if (use_batch_group_count) {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims);
  } else {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
  }

  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    dnums.add_output_spatial_dimensions(i);
  }

  for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    // Forward stride becomes window (rhs) dilation; forward dilation becomes
    // the window stride of this gradient convolution.
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.

    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    // INPUT =  [ A  B  C ]
    //
    // FILTER = [ x y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad top/left side the same as bottom/right
    //   side.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some training elements of the input. We do this by
    // applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings[2 * dim]
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.

  filter_backprop = xla::ConvGeneralDilated(
      activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
      rhs_dilation, dnums,
      /*feature_group_count=*/1,
      /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);

  // General depthwise case (depth multiplier > 1): the convolution above
  // produced an expanded [spatial..., in_depth, in_depth * multiplier]
  // backprop; mask and reduce it back down to the original filter shape.
  if (!use_batch_group_count && attrs.depthwise) {
    filter_backprop = ContractFilterForDepthwiseBackprop(
        filter_shape, filter_backprop, activations.builder());
  }

  return filter_backprop;
}
545 
546 }  // namespace tensorflow
547