/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/shape_inference.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/window_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 namespace {
IsForwardConvolutionCanonical(const HloInstruction & conv)33 bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
34   CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
35         conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
36   return window_util::HasSymmetricPadding(conv.window()) &&
37          !window_util::HasNegativePadding(conv.window()) &&
38          !window_util::HasDilation(conv.window());
39 }
40 
41 // If the (positive and negative) padding on the input operand of a convolution
42 // can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
43 // dilation), returns kPad and/or kSlice instructions that explicitly apply the
44 // padding; otherwise returns the original input operand. When there is both
45 // positive padding (including dilation) and negative padding, we insert both
46 // kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
47 // into a kPad or kSlice op.
MaybePaddedAndSlicedInput(Window * conv_window,const ConvolutionDimensionNumbers & conv_dnums,HloInstruction * input)48 HloInstruction* MaybePaddedAndSlicedInput(
49     Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
50     HloInstruction* input) {
51   HloComputation* computation = input->parent();
52   if (!window_util::HasSymmetricPadding(*conv_window) ||
53       window_util::HasBaseDilation(*conv_window)) {
54     // If padding is uneven or has dilation, we insert a kPad instruction that
55     // applies positive padding and dilation.
56     //
57     // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
58     // moving all the padding into an explicit pad op, we should keep as much
59     // padding inside of cudnn as possible, on the assumption that padding
60     // within cudnn is basically free, whereas a kPad's cost increases as the
61     // amount of padding increases.
62     PaddingConfig padding_config =
63         MakeNoPaddingConfig(input->shape().dimensions_size());
64     for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
65       int64 dim = conv_dnums.input_spatial_dimensions(i);
66       if (conv_window->dimensions(i).padding_low() > 0) {
67         padding_config.mutable_dimensions(dim)->set_edge_padding_low(
68             conv_window->dimensions(i).padding_low());
69         conv_window->mutable_dimensions(i)->set_padding_low(0);
70       }
71       if (conv_window->dimensions(i).padding_high() > 0) {
72         padding_config.mutable_dimensions(dim)->set_edge_padding_high(
73             conv_window->dimensions(i).padding_high());
74         conv_window->mutable_dimensions(i)->set_padding_high(0);
75       }
76       if (conv_window->dimensions(i).base_dilation() != 1) {
77         padding_config.mutable_dimensions(dim)->set_interior_padding(
78             conv_window->dimensions(i).base_dilation() - 1);
79         conv_window->mutable_dimensions(i)->set_base_dilation(1);
80       }
81     }
82     PrimitiveType element_type = input->shape().element_type();
83     HloInstruction* padding = computation->AddInstruction(
84         HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
85     input = MakePadHlo(input, padding, padding_config).ValueOrDie();
86   }
87 
88   if (window_util::HasNegativePadding(*conv_window)) {
89     // If the window has negative padding, insert a kSlice that explicitly
90     // applies negative padding.
91     //
92     // For each dimension, initialize the start index to 0 and the limit index
93     // to the size of that dimension.
94     std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
95     std::vector<int64> limit_indices(input->shape().dimensions().begin(),
96                                      input->shape().dimensions().end());
97     std::vector<int64> strides(input->shape().dimensions_size(), 1);
98     for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
99       int64 dim = conv_dnums.input_spatial_dimensions(i);
100       // If dimension "dim" has negative padding, increase the start index or
101       // decrement the limit index by the amount of negative padding.
102       if (conv_window->dimensions(i).padding_low() < 0) {
103         start_indices[dim] += -conv_window->dimensions(i).padding_low();
104         conv_window->mutable_dimensions(i)->set_padding_low(0);
105       }
106       if (conv_window->dimensions(i).padding_high() < 0) {
107         limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
108         conv_window->mutable_dimensions(i)->set_padding_high(0);
109       }
110     }
111 
112     input =
113         MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
114   }
115 
116   return input;
117 }
118 
119 // If the padding on the kernel operand of a convolution can't be folded into a
120 // cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
121 // explicitly applies the padding; otherwise returns the original kernel
122 // operand.
MaybePaddedKernel(const Window & conv_window,const ConvolutionDimensionNumbers & conv_dnums,HloInstruction * kernel)123 HloInstruction* MaybePaddedKernel(const Window& conv_window,
124                                   const ConvolutionDimensionNumbers& conv_dnums,
125                                   HloInstruction* kernel) {
126   if (!window_util::HasWindowDilation(conv_window)) {
127     return kernel;
128   }
129 
130   // Compute the shape and padding config of the pad to be inserted.
131   PaddingConfig padding_config;
132   for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
133     padding_config.add_dimensions();
134   }
135   for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
136     int64 dim = conv_dnums.kernel_spatial_dimensions(i);
137     padding_config.mutable_dimensions(dim)->set_interior_padding(
138         conv_window.dimensions(i).window_dilation() - 1);
139   }
140 
141   HloComputation* computation = kernel->parent();
142   PrimitiveType element_type = kernel->shape().element_type();
143   HloInstruction* padding = computation->AddInstruction(
144       HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
145   return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
146 }
147 }  // namespace
148 
CanonicalizeForwardConvolution(HloInstruction * conv)149 bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution(
150     HloInstruction* conv) {
151   if (IsForwardConvolutionCanonical(*conv)) {
152     return false;
153   }
154 
155   // Insert slices and/or pads between the convolution and its input and/or
156   // kernel operand.
157   Window new_conv_window = conv->window();
158   HloInstruction* new_input = MaybePaddedAndSlicedInput(
159       &new_conv_window, conv->convolution_dimension_numbers(),
160       conv->mutable_operand(0));
161   HloInstruction* new_kernel =
162       MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
163                         conv->mutable_operand(1));
164 
165   // Remove the window dilation from convolution's window field. These paddings
166   // are made explicit with the pads inserted by MaybePaddedKernel().
167   for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
168     WindowDimension* dim = new_conv_window.mutable_dimensions(i);
169 
170     // The size of the kernel may have changed so update the Window to match.
171     dim->set_size(new_kernel->shape().dimensions(
172         conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
173     dim->set_window_dilation(1);
174   }
175 
176   // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
177   // out the shape of conv_result.
178   VLOG(1) << "Canonicalizing forward conv";
179   std::vector<HloInstruction*> operands(conv->operands().begin(),
180                                         conv->operands().end());
181   operands[0] = new_input;
182   operands[1] = new_kernel;
183   auto new_conv = conv->parent()->AddInstruction(
184       conv->CloneWithNewOperands(conv->shape(), operands));
185   new_conv->set_window(new_conv_window);
186   VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
187           << new_conv->ToString();
188   TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
189   return true;
190 }
191 
192 namespace {
IncreasePaddingLowBy(int64 delta,WindowDimension * window_dim)193 void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
194   window_dim->set_padding_low(window_dim->padding_low() + delta);
195 }
196 
IncreasePaddingHighBy(int64 delta,WindowDimension * window_dim)197 void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
198   window_dim->set_padding_high(window_dim->padding_high() + delta);
199 }
200 }  // namespace
201 
CanonicalizeBackwardFilterConvolution(HloInstruction * backward_conv)202 bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
203     HloInstruction* backward_conv) {
204   CHECK_EQ(backward_conv->custom_call_target(),
205            kCudnnConvBackwardFilterCallTarget);
206   if (window_util::HasSymmetricPadding(backward_conv->window())) {
207     return false;
208   }
209 
210   // A backward filter convolution with uneven padding can be canonicalized to
211   // one with even padding by padding the activations (input) beforehand. For
212   // example,
213   //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
214   // is equivalent to
215   //   ABCD0 = Pad(ABCD, padding_high=1)
216   //   BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1)
217   // We choose the lesser of padding_low and padding_high as the new padding.
218   HloInstruction* input = backward_conv->mutable_operand(0);
219   Window new_backward_conv_window = backward_conv->window();
220   // input_padding_config is the config of the kPad to be inserted.
221   PaddingConfig input_padding_config =
222       MakeNoPaddingConfig(input->shape().rank());
223   ConvolutionDimensionNumbers backward_conv_dnums =
224       backward_conv->convolution_dimension_numbers();
225   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
226     int64 padding_low = backward_conv->window().dimensions(i).padding_low();
227     int64 padding_high = backward_conv->window().dimensions(i).padding_high();
228     if (padding_low < 0 || padding_high < 0) {
229       // TODO(b/32744257): The following canonicalization wouldn't remove
230       // negative padding in a backward convolution, and would therefore cause
231       // cuDNN convolution (which doesn't support negative padding) to fail.
232       return false;
233     }
234     // Compute the new, even padding for the backward conv operation.
235     int64 new_conv_padding = std::min(padding_low, padding_high);
236     int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
237     input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
238         padding_low - new_conv_padding);
239     input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
240         padding_high - new_conv_padding);
241 
242     // Since we move some padding from the backward convolution to the kPad, we
243     // need to accordingly reduce the padding amount of the backward convolution
244     // and its inner forward convolution.
245     auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
246     new_dim->set_padding_low(new_conv_padding);
247     new_dim->set_padding_high(new_conv_padding);
248   }
249 
250   // Create a new backward convolution replacing the old one.
251   HloComputation* computation = backward_conv->parent();
252   HloInstruction* output = backward_conv->mutable_operand(1);
253   HloInstruction* padding =
254       computation->AddInstruction(HloInstruction::CreateConstant(
255           LiteralUtil::Zero(input->shape().element_type())));
256   HloInstruction* padded_input =
257       MakePadHlo(input, padding, input_padding_config).ValueOrDie();
258 
259   // The shape of the backward_conv CustomCall is a tuple (conv_result,
260   // scratch_buffer).  Extract out the shape of conv_result.
261   HloInstruction* new_backward_conv =
262       computation->AddInstruction(backward_conv->CloneWithNewOperands(
263           backward_conv->shape(), {padded_input, output}));
264   new_backward_conv->set_window(new_backward_conv_window);
265 
266   VLOG(1) << "Canonicalizing backward filter conv";
267   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
268           << new_backward_conv->ToString();
269 
270   TF_CHECK_OK(
271       computation->ReplaceInstruction(backward_conv, new_backward_conv));
272   return true;
273 }
274 
CanonicalizeBackwardInputConvolution(HloInstruction * backward_conv)275 bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
276     HloInstruction* backward_conv) {
277   if (window_util::HasSymmetricPadding(backward_conv->window())) {
278     return false;
279   }
280 
281   Window new_backward_conv_window = backward_conv->window();
282   ConvolutionDimensionNumbers backward_conv_dnums =
283       backward_conv->convolution_dimension_numbers();
284 
285   // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
286   // Get the shape of conv_result.
287   Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
288 
289   Shape new_backward_conv_shape = backward_conv_shape;
290   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
291     int64 padding_low = backward_conv->window().dimensions(i).padding_low();
292     int64 padding_high = backward_conv->window().dimensions(i).padding_high();
293     if (padding_low < 0 || padding_high < 0) {
294       // TODO(b/32744257): The following canonicalization wouldn't remove
295       // negative padding in a backward convolution, and would therefore cause
296       // cuDNN convolution (which doesn't support negative padding) to fail.
297       return false;
298     }
299     // If the backward convolution has uneven padding on the activations, we
300     // move some padding on the larger end to "internal" padding, so that the
301     // backward convolution produces larger activations which get sliced later.
302     //
303     // For example, suppose we have a non-canonical HLO
304     //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
305     // where the amount of padding low is larger, we can canonicalize it to
306     //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
307     //   [A] = Slice([B A])
308     if (padding_low > padding_high) {
309       IncreasePaddingLowBy(padding_high - padding_low,
310                            new_backward_conv_window.mutable_dimensions(i));
311     } else if (padding_low < padding_high) {
312       IncreasePaddingHighBy(padding_low - padding_high,
313                             new_backward_conv_window.mutable_dimensions(i));
314     }
315     // Decreasing the padding by X *increases* the size of our output by X.
316     int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
317     new_backward_conv_shape.set_dimensions(
318         dim, new_backward_conv_shape.dimensions(dim) +
319                  std::abs(padding_low - padding_high));
320   }
321 
322   // Create a new backward convolution replacing the old one.
323   HloComputation* computation = backward_conv->parent();
324   HloInstruction* output = backward_conv->mutable_operand(0);
325   HloInstruction* filter = backward_conv->mutable_operand(1);
326 
327   HloInstruction* new_backward_conv_call =
328       computation->AddInstruction(backward_conv->CloneWithNewOperands(
329           ShapeUtil::MakeTupleShape(
330               {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
331           {output, filter}));
332   new_backward_conv_call->set_window(new_backward_conv_window);
333 
334   // The CustomCall created above returns a tuple (conv_result, scratch_memory).
335   // Extract out the two elements.
336   HloInstruction* new_backward_conv =
337       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
338           new_backward_conv_shape, new_backward_conv_call, 0));
339   HloInstruction* new_backward_conv_scratch =
340       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
341           new_backward_conv_call->shape().tuple_shapes(1),
342           new_backward_conv_call, 1));
343 
344   // Slice the new backward convolution.
345   //
346   // Initialize start_indices and limit_indices as no slicing.
347   std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
348                                    0LL);
349   std::vector<int64> limit_indices(
350       new_backward_conv->shape().dimensions().begin(),
351       new_backward_conv->shape().dimensions().end());
352   std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
353   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
354     int64 padding_low = backward_conv->window().dimensions(i).padding_low();
355     int64 padding_high = backward_conv->window().dimensions(i).padding_high();
356     int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
357     if (padding_low > padding_high) {
358       // If the amount of low padding (of the old backward convolution) is
359       // larger, we internally pad the low end of the activations and slice
360       // internal padding out here.
361       start_indices[dim] += padding_low - padding_high;
362     } else if (padding_low < padding_high) {
363       // If the amount of high padding is larger, we slice out the internal
364       // padding on the high end.
365       limit_indices[dim] -= padding_high - padding_low;
366     }
367   }
368 
369   // Replace the old backward convolution with the slice.
370   Shape slice_shape =
371       ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
372                                       limit_indices, strides)
373           .ConsumeValueOrDie();
374   CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
375       << ShapeUtil::HumanString(slice_shape) << " vs "
376       << ShapeUtil::HumanString(backward_conv_shape);
377 
378   HloInstruction* slice = computation->AddInstruction(
379       HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
380                                   start_indices, limit_indices, strides));
381   HloInstruction* new_tuple = computation->AddInstruction(
382       HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
383 
384   VLOG(1) << "Canonicalizing backward input conv";
385   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
386           << new_tuple->ToString();
387 
388   TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
389   return true;
390 }
391 
RunOnComputation(HloComputation * computation)392 StatusOr<bool> CudnnConvPaddingLegalization::RunOnComputation(
393     HloComputation* computation) {
394   bool changed = false;
395   std::vector<HloCustomCallInstruction*> convs;
396   for (auto* instr : computation->instructions()) {
397     if (IsCustomCallToDnnConvolution(*instr)) {
398       convs.push_back(Cast<HloCustomCallInstruction>(instr));
399     }
400   }
401   for (HloCustomCallInstruction* instruction : convs) {
402     TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
403     changed |= [&] {
404       switch (kind) {
405         case CudnnConvKind::kForward:
406         case CudnnConvKind::kForwardActivation:
407           return CanonicalizeForwardConvolution(instruction);
408         case CudnnConvKind::kBackwardInput:
409           return CanonicalizeBackwardInputConvolution(instruction);
410         case CudnnConvKind::kBackwardFilter:
411           return CanonicalizeBackwardFilterConvolution(instruction);
412       }
413     }();
414   }
415   return changed;
416 }
417 
Run(HloModule * module)418 StatusOr<bool> CudnnConvPaddingLegalization::Run(HloModule* module) {
419   bool changed = false;
420   for (HloComputation* computation : module->MakeNonfusionComputations()) {
421     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
422     changed |= result;
423   }
424   return changed;
425 }
426 
427 }  // namespace gpu
428 }  // namespace xla
429