/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
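// A forward convolution is "canonical" if it can be passed directly to the
// cuDNN libcall: its window must have symmetric, non-negative padding and no
// base or window dilation.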
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// If the (positive and negative) padding on the input operand of a convolution
// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
// dilation), returns kPad and/or kSlice instructions that explicitly apply the
// padding; otherwise returns the original input operand. When there is both
// positive padding (including dilation) and negative padding, we insert both
// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
// into a kPad or kSlice op.
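//
// As an illustrative 1D example: with padding_low=2, padding_high=-1, and
// base_dilation=2 on an input [a b c], this function emits
//   t = Pad([a b c], low=2, interior=1)   // yields [0 0 a 0 b 0 c]
//   u = Slice(t, start=0, limit=6)        // yields [0 0 a 0 b 0]
// and resets the corresponding window dimension to padding=0, dilation=1.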
HloInstruction* MaybePaddedAndSlicedInput(
    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(*conv_window) ||
      window_util::HasBaseDilation(*conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      if (conv_window->dimensions(i).padding_low() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
            conv_window->dimensions(i).padding_low());
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
            conv_window->dimensions(i).padding_high());
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
      if (conv_window->dimensions(i).base_dilation() != 1) {
        padding_config.mutable_dimensions(dim)->set_interior_padding(
            conv_window->dimensions(i).base_dilation() - 1);
        conv_window->mutable_dimensions(i)->set_base_dilation(1);
      }
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    input = MakePadHlo(input, padding, padding_config).ValueOrDie();
  }

  if (window_util::HasNegativePadding(*conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
                                     input->shape().dimensions().end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increment the start index or
      // decrement the limit index by the amount of negative padding.
      if (conv_window->dimensions(i).padding_low() < 0) {
        start_indices[dim] += -conv_window->dimensions(i).padding_low();
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() < 0) {
        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
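//
// As an illustrative 1D example: a kernel [x y z] with window_dilation=2
// becomes Pad([x y z], interior=1), i.e. [x 0 y 0 z], after which the caller
// can reset the window dilation factor to 1.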
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
}  // namespace

bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution(
    HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  Window new_conv_window = conv->window();
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      &new_conv_window, conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the window dilation from the convolution's window field. The
  // dilation is now applied explicitly by the pads inserted in
  // MaybePaddedKernel().
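  // For example, a kernel [x y z] with window_dilation=2 has been padded to
  // [x 0 y 0 z], so the window size for that dimension is updated from 3 to 5
  // below and the dilation factor is reset to 1.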
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer), so the
  // clone below keeps the original tuple shape.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
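// Helpers that adjust a window dimension's padding in place. Note that the
// delta may be negative, in which case these calls decrease the padding (as
// CanonicalizeBackwardInputConvolution below relies on).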
void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
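  // In the example above, new_conv_padding = min(1, 2) = 1, so the kPad
  // applies edge_padding_low = 1 - 1 = 0 and edge_padding_high = 2 - 1 = 1.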
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to reduce the padding of the backward convolution (and of its inner
    // forward convolution) accordingly.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(input->shape().element_type())));
  HloInstruction* padded_input =
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_buffer),
  // so the clone below keeps the original tuple shape.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_input, output}));
  new_backward_conv->set_window(new_backward_conv_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of low padding is larger. We can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
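    // Here padding_low - padding_high = 1, so the canonicalized convolution
    // produces one extra element on the low end ([B A] instead of [A]), and
    // the slice below advances start_indices for that dimension by 1 to
    // recover [A].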
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
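    // (In the example above, |padding_low - padding_high| = 1, so the
    // conv_result dimension grows by one element, which is sliced off below.)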
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

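  // The scratch buffer is given a zero-byte U8 shape here as a placeholder;
  // the actual scratch size is expected to be filled in later, when a
  // convolution algorithm is picked for the call.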
  HloInstruction* new_backward_conv_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  new_backward_conv_call->set_window(new_backward_conv_window);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64> start_indices(
      new_backward_conv->shape().dimensions_size(), 0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(),
                             1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // the internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

StatusOr<bool> CudnnConvPaddingLegalization::RunOnComputation(
    HloComputation* computation) {
  bool changed = false;
  std::vector<HloCustomCallInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  for (HloCustomCallInstruction* instruction : convs) {
    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
    changed |= [&] {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return CanonicalizeForwardConvolution(instruction);
        case CudnnConvKind::kBackwardInput:
          return CanonicalizeBackwardInputConvolution(instruction);
        case CudnnConvKind::kBackwardFilter:
          return CanonicalizeBackwardFilterConvolution(instruction);
      }
    }();
  }
  return changed;
}

StatusOr<bool> CudnnConvPaddingLegalization::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla