/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for 2D convolution.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

44 // Returns the expanded size of a filter used for depthwise convolution.
45 // If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape & shape)46 xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
47 int num_dims = shape.dimensions_size();
48 CHECK_GE(num_dims, 2); // Crash OK
49 xla::Shape expanded_shape = shape;
50 expanded_shape.set_dimensions(
51 num_dims - 1,
52 shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
53 return expanded_shape;
54 }
55
56 // Create a mask for depthwise convolution that will make a normal convolution
57 // produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
58 // depthwise filter this returns a [2, 2, 3, 6] tensor
59 // 1 1 0 0 0 0 1 1 0 0 0 0
60 // 0 0 1 1 0 0 0 0 1 1 0 0
61 // 0 0 0 0 1 1 0 0 0 0 1 1
62 //
63 // 1 1 0 0 0 0 1 1 0 0 0 0
64 // 0 0 1 1 0 0 0 0 1 1 0 0
65 // 0 0 0 0 1 1 0 0 0 0 1 1
66 //
67 // The first step is to create a iota A with iota_dimension = 2
68 // 0 0 0 0 0 0 0 0 0 0 0 0
69 // 1 1 1 1 1 1 1 1 1 1 1 1
70 // 2 2 2 2 2 2 2 2 2 2 2 2
71 //
72 // 0 0 0 0 0 0 0 0 0 0 0 0
73 // 1 1 1 1 1 1 1 1 1 1 1 1
74 // 2 2 2 2 2 2 2 2 2 2 2 2
75 //
76 // and another iota B with iota_dimension = 3
77 // 0 1 2 3 4 5 0 1 2 3 4 5
78 // 0 1 2 3 4 5 0 1 2 3 4 5
79 // 0 1 2 3 4 5 0 1 2 3 4 5
80 //
81 // 0 1 2 3 4 5 0 1 2 3 4 5
82 // 0 1 2 3 4 5 0 1 2 3 4 5
83 // 0 1 2 3 4 5 0 1 2 3 4 5
84 //
85 // and divide B by 2 to get
86 // 0 0 1 1 2 2 0 0 1 1 2 2
87 // 0 0 1 1 2 2 0 0 1 1 2 2
88 // 0 0 1 1 2 2 0 0 1 1 2 2
89 //
90 // 0 0 1 1 2 2 0 0 1 1 2 2
91 // 0 0 1 1 2 2 0 0 1 1 2 2
92 // 0 0 1 1 2 2 0 0 1 1 2 2
93 //
94 // Finally compare A and B and return the result at the beginning of the
95 // comment.
CreateExpandedFilterMask(const xla::Shape & filter_shape,xla::XlaBuilder * builder)96 xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
97 xla::XlaBuilder* builder) {
98 xla::Shape expanded_filter_shape =
99 ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
100 int64 depthwise_multiplier =
101 filter_shape.dimensions(filter_shape.dimensions_size() - 1);
102
103 // Create two iotas with the shape of the expanded filter, one of them with
104 // the iota dimension chosen as the feature dimension, and the other a iota
105 // with the iota dimension chosen as the expanded output feature dimension.
106 std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
107 expanded_filter_shape.dimensions().end());
108 xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
109 xla::XlaOp input_feature_iota = xla::Iota(
110 builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
111 xla::XlaOp expanded_feature_iota = xla::Iota(
112 builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
113
114 // Divide 'expanded_feature_iota' by the depthwise_multiplier to create
115 // [0 0 1 1 2 2] ... in the example in the function comment.
116 expanded_feature_iota =
117 xla::Div(expanded_feature_iota,
118 XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
119 depthwise_multiplier));
120
121 // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
122 // diagonal predicate.
123 return xla::Eq(expanded_feature_iota, input_feature_iota);
124 }
125
126 // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
127 // build a depthwise convolution.
ReshapeFilterForDepthwiseConvolution(const xla::Shape & filter_shape,const xla::XlaOp & filter)128 xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
129 const xla::XlaOp& filter) {
130 int64 input_feature_dim = filter_shape.dimensions_size() - 2;
131 int64 output_feature_dim = filter_shape.dimensions_size() - 1;
132 int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
133 int64 input_feature = filter_shape.dimensions(input_feature_dim);
134
135 // Create a [H, W, ..., 1, N*M] reshape of the filter.
136 xla::Shape implicit_broadcast_filter_shape = filter_shape;
137 implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
138 implicit_broadcast_filter_shape.set_dimensions(
139 output_feature_dim, depthwise_multiplier * input_feature);
140 return xla::Reshape(
141 filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
142 }
143
144 // Reduces the results of the convolution with an expanded filter to the
145 // non-expanded filter.
ContractFilterForDepthwiseBackprop(const xla::Shape & filter_shape,const xla::XlaOp & filter_backprop,xla::XlaBuilder * builder)146 xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
147 const xla::XlaOp& filter_backprop,
148 xla::XlaBuilder* builder) {
149 auto masked_expanded_filter =
150 xla::Select(CreateExpandedFilterMask(filter_shape, builder),
151 filter_backprop, xla::ZerosLike(filter_backprop));
152
153 auto elem_type = filter_shape.element_type();
154 return xla::Reshape(
155 // This reduce does not need inputs to be converted with
156 // XlaHelpers::SumAccumulationType() since the select above guarantees
157 // that only one element is non zero, so there cannot be accumulated
158 // precision error.
159 xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
160 CreateScalarAddComputation(elem_type, builder),
161 {filter_shape.dimensions_size() - 2}),
162 xla::AsInt64Slice(filter_shape.dimensions()));
163 }
164
165 // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
166 // convolutions (as currently implemented).
CheckConvAttrs(const ConvOpAttrs & attrs)167 Status CheckConvAttrs(const ConvOpAttrs& attrs) {
168 const int num_dims = attrs.num_spatial_dims + 2;
169 if (attrs.strides.size() != num_dims) {
170 return errors::InvalidArgument("Sliding window strides field must specify ",
171 num_dims, " dimensions");
172 }
173 int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
174 int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
175 if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
176 return errors::Unimplemented(
177 "Current implementation does not yet support strides in the batch and "
178 "depth dimensions.");
179 }
180 if (attrs.dilations.size() != num_dims) {
181 return errors::InvalidArgument("Dilations field must specify ", num_dims,
182 " dimensions");
183 }
184 if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
185 return errors::Unimplemented(
186 "Current implementation does not support dilations in the batch and "
187 "depth dimensions.");
188 }
189 for (int i = 0; i < attrs.num_spatial_dims; ++i) {
190 int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
191 if (attrs.dilations[input_dim] < 1) {
192 return errors::Unimplemented("Dilation values must be positive; ", i,
193 "th spatial dimension had dilation ",
194 attrs.dilations[input_dim]);
195 }
196 }
197 return Status::OK();
198 }
199
200 // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
201 // to TensorShapes.
ConvBackpropComputeDimensionsV2XlaShapes(StringPiece label,int num_spatial_dims,const xla::Shape & input_shape,const xla::Shape & filter_shape,const xla::Shape & out_backprop_shape,absl::Span<const int32> dilations,const std::vector<int32> & strides,Padding padding,TensorFormat data_format,ConvBackpropDimensions * dims,absl::Span<const int64> explicit_paddings)202 Status ConvBackpropComputeDimensionsV2XlaShapes(
203 StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
204 const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
205 absl::Span<const int32> dilations, const std::vector<int32>& strides,
206 Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
207 absl::Span<const int64> explicit_paddings) {
208 TensorShape input_tensor_shape, filter_tensor_shape,
209 out_backprop_tensor_shape;
210 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
211 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
212 TF_RETURN_IF_ERROR(
213 XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
214 return ConvBackpropComputeDimensionsV2(
215 label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
216 out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
217 data_format, dims);
218 }

}  // anonymous namespace

Create(int num_spatial_dims,bool depthwise,OpKernelConstruction * ctx)222 xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
223 bool depthwise,
224 OpKernelConstruction* ctx) {
225 ConvOpAttrs attrs;
226 attrs.num_spatial_dims = num_spatial_dims;
227 attrs.depthwise = depthwise;
228 TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
229 TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
230 TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
231 if (attrs.padding == EXPLICIT) {
232 TF_RETURN_IF_ERROR(
233 ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
234 }
235
236 string data_format;
237 TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
238 if (!FormatFromString(data_format, &attrs.data_format)) {
239 return errors::InvalidArgument("Invalid data format: ", data_format);
240 }
241
242 return attrs;
243 }
244
MakeXlaForwardConvOp(StringPiece,xla::XlaOp conv_input,xla::XlaOp filter,const ConvOpAttrs & attrs)245 xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
246 xla::XlaOp conv_input,
247 xla::XlaOp filter,
248 const ConvOpAttrs& attrs) {
249 TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
250
251 auto* builder = conv_input.builder();
252 TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
253 // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
254 TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
255
256 // For 2D convolution, there should be 4 dimensions.
257 int num_dims = attrs.num_spatial_dims + 2;
258 if (input_shape.dimensions_size() != num_dims) {
259 return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
260 input_shape.DebugString());
261 }
262 if (filter_shape.dimensions_size() != num_dims) {
263 return errors::InvalidArgument(
264 "filter must be ", num_dims,
265 "-dimensional: ", filter_shape.DebugString());
266 }
267
268 // The last two dimensions of the filter are the input and output shapes.
269 int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
270 int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
271
272 int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
273 // The 'C' dimension for input is in_depth. It must be the same as
274 // the filter's in_depth.
275 if (in_depth != input_shape.dimensions(feature_dim)) {
276 return errors::InvalidArgument(
277 "input and filter must have the same depth: ", in_depth, " vs ",
278 input_shape.dimensions(feature_dim));
279 }
280
281 if (attrs.depthwise) {
282 filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
283 }
284
285 xla::ConvolutionDimensionNumbers dims;
286 std::vector<int64> window_strides(attrs.num_spatial_dims);
287 std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
288 std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
289 std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
290
291 dims.set_input_batch_dimension(batch_dim);
292 dims.set_output_batch_dimension(batch_dim);
293 dims.set_input_feature_dimension(feature_dim);
294 dims.set_output_feature_dimension(feature_dim);
295 dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
296 dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
297
298 for (int i = 0; i < attrs.num_spatial_dims; ++i) {
299 const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
300 dims.add_input_spatial_dimensions(dim);
301 dims.add_kernel_spatial_dimensions(i);
302 dims.add_output_spatial_dimensions(dim);
303 window_strides[i] = attrs.strides.at(dim);
304 rhs_dilation[i] = attrs.dilations.at(dim);
305
306 if (attrs.padding == EXPLICIT) {
307 padding[i] = {attrs.explicit_paddings.at(dim * 2),
308 attrs.explicit_paddings.at(dim * 2 + 1)};
309 }
310
311 int64 unused_output_size;
312 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
313 input_shape.dimensions(dim), filter_shape.dimensions(i),
314 rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
315 &padding[i].first, &padding[i].second));
316 }
317
318 return xla::ConvGeneralDilated(
319 conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
320 dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
321 }
322
MakeXlaBackpropInputConvOp(StringPiece type_string,const xla::Shape & input_shape,xla::XlaOp filter,xla::XlaOp out_backprop,const ConvOpAttrs & attrs)323 xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
324 StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
325 xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
326 TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
327
328 int num_dims = attrs.num_spatial_dims + 2;
329 int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
330 int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
331
332 auto* builder = filter.builder();
333 TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
334 TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
335 builder->GetShape(out_backprop));
336
337 xla::Shape expanded_filter_shape =
338 attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
339 : filter_shape;
340 // Reuse dimension computation logic from conv_grad_ops.cc.
341 ConvBackpropDimensions dims;
342 TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
343 type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
344 out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
345 attrs.data_format, &dims, attrs.explicit_paddings));
346
347 // The input gradients are computed by a convolution of the output
348 // gradients and the filter, with some appropriate padding. See the
349 // comment at the top of conv_grad_ops.h for details.
350
351 xla::ConvolutionDimensionNumbers dnums;
352 dnums.set_input_batch_dimension(batch_dim);
353 dnums.set_output_batch_dimension(batch_dim);
354 dnums.set_input_feature_dimension(feature_dim);
355 dnums.set_output_feature_dimension(feature_dim);
356
357 // TF filter shape is [ H, W, ..., inC, outC ]
358 // Transpose the input and output features for computing the gradient.
359 dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
360 dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
361
362 std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
363 std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
364 std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
365 std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
366 std::vector<int64> ones(attrs.num_spatial_dims, 1);
367 for (int i = 0; i < attrs.num_spatial_dims; ++i) {
368 int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
369 dnums.add_input_spatial_dimensions(dim);
370 dnums.add_kernel_spatial_dimensions(i);
371 dnums.add_output_spatial_dimensions(dim);
372
373 kernel_spatial_dims[i] = i;
374 padding[i] = {dims.spatial_dims[i].pad_before,
375 dims.spatial_dims[i].pad_after};
376 lhs_dilation[i] = dims.spatial_dims[i].stride;
377 rhs_dilation[i] = attrs.dilations[dim];
378 }
379
380 // Mirror the filter in the spatial dimensions.
381 xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
382
383 // activation gradients
384 // = gradients (with padding and dilation) <conv> mirrored_weights
385 return xla::ConvGeneralDilated(
386 out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
387 lhs_dilation, rhs_dilation, dnums,
388 /*feature_group_count=*/
389 attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
390 filter_shape.dimensions(attrs.num_spatial_dims + 1)
391 : 1);
392 }
393
MakeXlaBackpropFilterConvOp(StringPiece type_string,xla::XlaOp activations,const xla::Shape & filter_shape,xla::XlaOp gradients,const ConvOpAttrs & attrs)394 xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
395 StringPiece type_string, xla::XlaOp activations,
396 const xla::Shape& filter_shape, xla::XlaOp gradients,
397 const ConvOpAttrs& attrs) {
398 TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
399
400 auto* builder = activations.builder();
401 TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
402 builder->GetShape(activations));
403 TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
404 builder->GetShape(gradients));
405 xla::XlaOp filter_backprop;
406
407 xla::Shape input_shape = activations_shape;
408 xla::Shape output_shape = out_backprop_shape;
409
410 TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
411 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
412 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
413 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));
414
415 const xla::Shape expanded_filter_shape =
416 attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
417 : filter_shape;
418 // Reuse dimension computation logic from conv_grad_ops.cc.
419 ConvBackpropDimensions dims;
420 // The filter gradients are computed by a convolution of the input
421 // activations and the output gradients, with some appropriate padding.
422 // See the comment at the top of conv_grad_ops.h for details.
423 xla::ConvolutionDimensionNumbers dnums;
424
425 TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
426 type_string, attrs.num_spatial_dims, activations_shape,
427 expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
428 attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));
429
430 // The activations (inputs) form the LHS of the convolution.
431 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
432 // For the gradient computation, we flip the roles of the batch and
433 // feature dimensions.
434 // Each spatial entry has size in_depth * batch
435
436 // The last two dimensions of the filter are the input and output shapes.
437 int num_dims = attrs.num_spatial_dims + 2;
438 int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
439 int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
440
441 bool use_batch_group_count =
442 filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
443
444 std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
445 std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
446 std::vector<int64> window_strides(attrs.num_spatial_dims);
447 std::vector<int64> ones(attrs.num_spatial_dims, 1);
448
449 // Swap n_dim and c_dim in the activations.
450 dnums.set_input_batch_dimension(c_dim);
451 dnums.set_input_feature_dimension(n_dim);
452
453 // The gradients become the RHS of the convolution.
454 // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
455 // where the batch becomes the input feature for the convolution.
456 dnums.set_kernel_input_feature_dimension(n_dim);
457 dnums.set_kernel_output_feature_dimension(c_dim);
458
459 // The dimension swap below is needed because filter shape is KH,KW,F,DM.
460 if (use_batch_group_count) {
461 dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1);
462 dnums.set_output_feature_dimension(attrs.num_spatial_dims);
463 } else {
464 dnums.set_output_batch_dimension(attrs.num_spatial_dims);
465 dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
466 }
467
468 // Tensorflow filter shape is [ H, W, ..., inC, outC ].
469 for (int i = 0; i < attrs.num_spatial_dims; ++i) {
470 dnums.add_output_spatial_dimensions(i);
471 }
472
473 for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
474 int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
475 dnums.add_input_spatial_dimensions(dim);
476 dnums.add_kernel_spatial_dimensions(dim);
477 rhs_dilation[i] = dims.spatial_dims[i].stride;
478 window_strides[i] = attrs.dilations[dim];
479
480 // We will also need to pad the input with zeros such that after the
481 // convolution, we get the right size for the filter.
482 // The padded_in_rows should be such that when we convolve this with the
483 // expanded_out_rows as a filter, we should get filter_rows back.
484
485 const int64 padded_in_size =
486 dims.spatial_dims[i].expanded_output_size +
487 (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
488
489 // However it can be smaller than input_rows: in this
490 // case it means some of the inputs are not used.
491 //
492 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
493 //
494 // INPUT = [ A B C ]
495 //
496 // FILTER = [ x y ]
497 //
498 // and the output will only have one column: a = A * x + B * y
499 //
500 // and input "C" is not used at all.
501 //
502 // We apply negative padding in this case.
503 const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
504
505 // + For the EXPLICIT padding, we pad the top/left side with the explicit
506 // padding and pad the bottom/right side with the remaining space.
507 // + For the VALID padding, we don't pad anything on the top/left side
508 // and pad the bottom/right side with the remaining space.
509 // + For the SAME padding, we pad top/left side the same as bottom/right
510 // side.
511 //
512 // In addition, if the padded input size is smaller than the input size,
513 // we need to ignore some training elements of the input. We do this by
514 // applying negative padding on the right/bottom.
515 const int64 pad_before = attrs.padding == Padding::EXPLICIT
516 ? attrs.explicit_paddings[2 * dim]
517 : attrs.padding == Padding::SAME
518 ? std::max<int64>(pad_total / 2, 0)
519 : 0;
520 padding[i] = {pad_before, pad_total - pad_before};
521 }
522
523 // Besides padding the input, we will also expand output_rows to
524 // expanded_out_rows = (output_rows - 1) * stride + 1
525 // with zeros in between:
526 //
527 // a . . . b . . . c . . . d . . . e
528 //
529 // This is done by specifying the window dilation factors in the
530 // convolution HLO below.
531
532 filter_backprop = xla::ConvGeneralDilated(
533 activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
534 rhs_dilation, dnums,
535 /*feature_group_count=*/1,
536 /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
537
538 if (!use_batch_group_count && attrs.depthwise) {
539 filter_backprop = ContractFilterForDepthwiseBackprop(
540 filter_shape, filter_backprop, activations.builder());
541 }
542
543 return filter_backprop;
544 }
545
}  // namespace tensorflow