/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"

#include <memory>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace gpu {

using se::dnn::DataLayout;
using se::dnn::FilterLayout;

// Returns (input, filter, output) layouts.
static std::tuple<DataLayout, FilterLayout, DataLayout>
HeuristicLayoutAssignment(const HloInstruction* instr,
                          se::StreamExecutor* stream_executor) {
  // DataLayout and FilterLayout use weird enum names. Translations:
  //   N <=> Batch or Output
  //   C <=> Depth or Input
  //   H <=> Y
  //   W <=> X
  //
  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
  //
  // If you have trouble keeping these straight, consider that all that matters
  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?

  constexpr auto kAllNCHW =
      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
                      DataLayout::kBatchDepthYX);
  constexpr auto kAllNHWC =
      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
                      DataLayout::kBatchYXDepth);

  // Integer convolution must use NHWC.
  if (primitive_util::IsIntegralType(
          instr->operand(0)->shape().element_type())) {
    return kAllNHWC;
  }

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  if (debug_options.xla_gpu_force_conv_nchw()) {
    VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
    return kAllNCHW;
  }

  if (debug_options.xla_gpu_force_conv_nhwc()) {
    VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
    return kAllNHWC;
  }

  // If we're not Volta or not fp16, or not conv2D, the decision is easy: Use
  // NCHW.
  if (instr->operand(0)->shape().element_type() != xla::PrimitiveType::F16 ||
      !IsVoltaOrLater(*stream_executor) ||
      instr->shape().tuple_shapes(0).dimensions_size() != 4) {
    return kAllNCHW;
  }

  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();

  // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
  // convs with stride are significantly faster with NCHW layouts.
  //
  // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
  // which on paper gives good performance. However, there are two observations:
  // * a mixed layout combination is more cuDNN-bug prone, based on empirical
  //   evidence.
  // * we've also observed that for mixed layouts, cuDNN transposes data back
  //   and forth from a different layout combination. If we end up with
  //   transposes anyway, we prefer to have them in XLA, as they can be fused.
  if (auto* dnn = stream_executor->AsDnn()) {
    auto version_status = dnn->GetVersion();
    if (version_status.ok()) {
      auto version = version_status.ConsumeValueOrDie();
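      // Lexicographic (major, minor) comparison, i.e. this is true for cuDNN
      // 7.3 and older.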
      if (std::make_tuple(version.major_version(), version.minor_version()) <=
              std::make_tuple(7, 3) &&
          instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
          window_util::HasStride(instr->window())) {
        return kAllNCHW;
      }
    }
  }

  // For other Volta f16 convolutions, use NHWC.
  return kAllNHWC;
}

// Adds layout constraints on the cudnn custom-call instruction. The layout
// constraints are represented in terms of minor_to_major fields of both
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
    HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
  Shape lhs_shape = instr->operand(0)->shape();
  Shape rhs_shape = instr->operand(1)->shape();
  Shape result_shape = instr->shape().tuple_shapes(0);

  Shape* input_shape;
  Shape* filter_shape;
  Shape* output_shape;

  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
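  // cuDNN always describes a convolution in terms of (input, filter, output),
  // so map the XLA operands and result onto whichever of those roles they play
  // for this convolution kind.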
  switch (kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      input_shape = &lhs_shape;
      filter_shape = &rhs_shape;
      output_shape = &result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      input_shape = &result_shape;
      filter_shape = &rhs_shape;
      output_shape = &lhs_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      input_shape = &lhs_shape;
      filter_shape = &result_shape;
      output_shape = &rhs_shape;
      break;
  }

  {
    DataLayout input;
    FilterLayout filter;
    DataLayout output;
    std::tie(input, filter, output) =
        HeuristicLayoutAssignment(instr, stream_executor_);

    TF_ASSIGN_OR_RETURN(
        std::tie(*input_shape->mutable_layout(),
                 *filter_shape->mutable_layout(),
                 *output_shape->mutable_layout()),
        StreamExecutorConvLayoutsToXlaLayouts(
            instr->convolution_dimension_numbers(), input, filter, output));
  }

  // The custom call returns a tuple of (actual_result, scratch_buffer);
  // call_result_buf is the logical buffer for actual_result, i.e. the buffer
  // that holds the result of the conv call.
  TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf,
                      constraints->points_to_analysis().GetBufferDefinedAt(
                          instr, /*index=*/{0}));

  // Set layouts of the instruction's shapes.
  TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0));
  TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
  TF_RETURN_IF_ERROR(
      constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
  // instr->operand(2), if it exists, is the bias buffer. There is no need to
  // assign a layout to it, as it has only one dimension.

  // instr->operand(3), if it exists, is the side input buffer.
  if (instr->operand_count() == 4) {
    if (kind != CudnnConvKind::kForwardActivation) {
      return InternalError(
          "Invalid convolution. Conv has a side input, but kind is not fused "
          "conv forward: %s",
          instr->ToString());
    }
    // The side input layout must match the output layout.
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3));
  }
  return Status::OK();
}

Status GpuLayoutAssignment::AddBackendConstraints(
    LayoutConstraints* constraints) {
  // Add convolution constraints in reverse postorder so that the earliest
  // convolution's layout propagates first. This reduces the likelihood of
  // fusion nodes with copies.
  auto post_order = constraints->computation()->MakeInstructionPostOrder();
  for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
       ++iterator) {
    HloInstruction* instruction = *iterator;
    if (IsCustomCallToDnnConvolution(*instruction)) {
      TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
          Cast<HloCustomCallInstruction>(instruction), constraints));
    }

    CHECK(!IsCublasGemm(*instruction))
        << "Gemm rewriting should run after layout assignment";

    // For batched dot we require the default layout.
    // TODO(b/112111608): This is overly conservative; the only real
    // restriction is that batch dimensions must be major.
    if (IsMatrixMultiplication(*instruction) &&
        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
      // Verify that the batch dims come before the row and col dims.
      DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
      CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
               dim_nums.rhs_batch_dimensions_size());
      CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
               instruction->shape().rank());
      for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
        CHECK_LT(batch_dim, instruction->shape().rank() - 2);
      }

      // Set both inputs and the output to the default layout.
      Shape op0_shape = instruction->operand(0)->shape();
      LayoutUtil::SetToDefaultLayout(&op0_shape);
      Shape op1_shape = instruction->operand(1)->shape();
      LayoutUtil::SetToDefaultLayout(&op1_shape);
      Shape output_shape = instruction->shape();
      LayoutUtil::SetToDefaultLayout(&output_shape);
      TF_RETURN_IF_ERROR(
          constraints->SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(
          constraints->SetOperandLayout(op1_shape, instruction, 1));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kFft) {
      // cuFFT requires a dim0 major layout.
      Shape op0_shape = instruction->operand(0)->shape();
      LayoutUtil::SetToDefaultLayout(&op0_shape);
      Shape output_shape = instruction->shape();
      LayoutUtil::SetToDefaultLayout(&output_shape);
      TF_RETURN_IF_ERROR(
          constraints->SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kSort &&
               instruction->operand(0)->shape().rank() > 1) {
      // Make sure that all the operands and the output(s) have the same layout.
      Shape keys_shape = instruction->operand(0)->shape();
      Layout keys_layout =
          LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        Shape shape = instruction->operand(i)->shape();
        *shape.mutable_layout() = keys_layout;
        TF_RETURN_IF_ERROR(
            constraints->SetOperandLayout(shape, instruction, i));
        const LogicalBuffer* output_buffer;
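        // A single-operand sort defines an array output at the empty shape
        // index; a variadic sort defines a tuple whose i-th element's buffer
        // lives at tuple index {i}.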
        if (instruction->shape().IsArray()) {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                   {}));
        } else {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                   {i}));
        }
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(keys_layout, *output_buffer));
      }
    } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
      // TODO(phawkins): Ideally we would relax this constraint. What we
      // actually want is:
      // a) the batch dimensions are major, in no particular order.
      // b) the two minor dimensions are in Fortran (column-major) order,
      // although for the 'a' argument we could potentially accept row-major
      // order and fold the transpose into the operator.
      auto set_fortran_layout = [](Shape* shape) {
        LayoutUtil::SetToDefaultLayout(shape);
        int n = shape->mutable_layout()->minor_to_major_size();
        CHECK_GE(n, 2);
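        // Swapping the two minor-most dimensions of the default (row-major)
        // layout makes the matrix dimensions column-major while keeping the
        // batch dimensions major.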
        std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
                  shape->mutable_layout()->mutable_minor_to_major()->at(1));
      };
      Shape op0_shape = instruction->operand(0)->shape();
      Shape op1_shape = instruction->operand(1)->shape();
      Shape output_shape = instruction->shape();
      set_fortran_layout(&op0_shape);
      set_fortran_layout(&op1_shape);
      set_fortran_layout(&output_shape);
      TF_RETURN_IF_ERROR(
          constraints->SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(
          constraints->SetOperandLayout(op1_shape, instruction, 1));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(output_shape, instruction));
    }
  }
  return Status::OK();
}

Status GpuLayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& layout_constraint,
    LayoutConstraints* constraints) {
  const HloInstruction* instruction = layout_constraint.instruction();

  // cudnn batchnorm forward inference's result must have the same layout as its
  // operand 0.
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() ==
          kCudnnBatchNormForwardInferenceCallTarget &&
      layout_constraint.operand_no() == 0) {
    TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
        layout_constraint.shape_layout().shape(), instruction));
  }

  // cudnn batchnorm forward training returns a tuple {output, mean,
  // inverse-stddev}.  mean and inverse-stddev are rank 1 and so have only one
  // possible layout, but output is not (necessarily) rank 1, and, like in
  // batchnorm forward inference, must have the same layout as operand 0.
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() ==
          kCudnnBatchNormForwardTrainingCallTarget &&
      layout_constraint.operand_no() == 0) {
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            instruction, /*index=*/{0}));
    TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
        layout_constraint.shape_layout().layout(), *out_buf));
  }

  // Like forward training, cudnn batchnorm backward returns a tuple {output,
  // mean, inverse-stddev}, and its operand 0 and 'output' must have the same
  // layout.  In addition, its operand 0 and operand 4 -- the 'operand' and
  // 'grad_output' parameters -- must have the same layout.
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
      (layout_constraint.operand_no() == 0 ||
       layout_constraint.operand_no() == 4)) {
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            instruction, /*index=*/{0}));
    TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
        layout_constraint.shape_layout().layout(), *out_buf));

    int64 operand_to_set = layout_constraint.operand_no() == 0 ? 4 : 0;
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
        layout_constraint.shape_layout().shape(), instruction, operand_to_set));
  }

  return LayoutAssignment::PropagateOperandConstraint(layout_constraint,
                                                      constraints);
}

Status GpuLayoutAssignment::PropagateBufferConstraint(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buf = buffer_constraint.buffer();
  const HloInstruction* instruction = buf.instruction();

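  // The buffer's shape with the newly constrained layout applied; this is what
  // gets handed to SetOperandLayout below.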
  Shape shape_with_layout = buf.shape();
  *shape_with_layout.mutable_layout() = buffer_constraint.layout();

  // Propagate output constraints to the operands of cudnn batchnorm ops.  This
  // is the same as PropagateOperandConstraint, just in the other direction.  We
  // need both to fulfill our contract with LayoutAssignment.
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() ==
          kCudnnBatchNormForwardInferenceCallTarget) {
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
        shape_with_layout, instruction, /*operand_no=*/0));
  }

  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() ==
          kCudnnBatchNormForwardTrainingCallTarget &&
      buf.index() == ShapeIndex({0})) {
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
        shape_with_layout, instruction, /*operand_no=*/0));
  }
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
      buf.index() == ShapeIndex({0})) {
    // batchnorm backward has two operands, "operand" and "grad_output", whose
    // layouts must both match that of the result at tuple-index 0.
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
        shape_with_layout, instruction, /*operand_no=*/0));
    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
        shape_with_layout, instruction, /*operand_no=*/4));
  }

  return LayoutAssignment::PropagateBufferConstraint(buffer_constraint,
                                                     constraints);
}

}  // namespace gpu
}  // namespace xla