/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, which
// leaves some extra space for cache.
constexpr int64 kSharedMemoryBudgetInBytes = 40000;

void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}

bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when the window size is larger than the
  // stride, since that means the same elements are read multiple times.
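  // For example, a 1D reduce-window with window size 3 and stride 1 reads
  // each interior element of its input up to three times.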
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}
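
// Returns the relative order of the nontrivial (size > 1) dimensions of
// `shape`, minor-to-major, renumbered to the range [0, true rank). For
// example, for f32[1,4,8,1] with layout {3,2,1,0}, the nontrivial dimensions
// in minor-to-major order are {2, 1}, which normalize to {1, 0}.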
std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
  std::vector<int64> relative_order;
  for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
    if (shape.dimensions(dim) > 1) {
      relative_order.push_back(dim);
    }
  }
  // Now normalize the dimensions to values between 0 and true rank - 1.
  std::vector<int64> sorted_dims = relative_order;
  std::sort(sorted_dims.begin(), sorted_dims.end());
  for (int64& dim : relative_order) {
    int64 sorted_index = std::distance(
        sorted_dims.begin(),
        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
    dim = sorted_index;
  }
  return relative_order;
}

}  // namespace
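
// Example for the check below: if the gathered parameters include
// f32[2,3]{1,0} and f32[2,3]{0,1}, their relative orders of nontrivial
// dimensions are {1, 0} and {0, 1} respectively, so the layouts are
// considered unfriendly and the function returns false.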
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64 max_true_rank = -1;
  std::vector<int64> max_rank_order;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() &&
        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
      max_true_rank = ShapeUtil::TrueRank(param->shape());
      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return !param->shape().IsArray() ||
           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
               max_rank_order;
  });
}

bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must "
               "be of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

bool IsInputFusibleReduction(const HloInstruction& instr) {
  // TODO(b/129089333): Don't fuse variadic reduce.
  if (instr.opcode() == HloOpcode::kReduce && instr.shape().IsTuple()) {
    return false;
  }

  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}
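
// Example for the compatibility check below: a reduce with operand shape
// f32[64,128] (output f32[64]) is loop-compatible with an elementwise op of
// shape f32[64,128], because the loop shape of such a reduction is its first
// operand's shape, which (given equal layouts) matches the elementwise
// output shape.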
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are
    // determined by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. An exception is
  // reduction-to-vector ops: here the input shape of the reduction (the
  // first operand's shape) and the reduction dimensions need to match.
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

bool IsInputFusibleScatter(const HloInstruction& instr) {
  return instr.opcode() == HloOpcode::kScatter ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
          instr.fused_expression_root()->opcode() == HloOpcode::kScatter);
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reduction and scatter
  // operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing.  We never generate kernels for unfused GTEs.  Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel.  Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

bool IsProducerConsumerFusible(const HloInstruction& producer,
                               const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    VLOG(5) << "Producer " << producer.name() << " is not loop-fusible";
    return false;
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    VLOG(5) << "Consumer " << consumer.name()
            << " is not input-fusible and not loop-fusible";
    return false;
  }

  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    VLOG(5) << "Producer " << producer.name()
            << " is not fusible as it is a multi-output fusion";
    return false;
  }

  if (CreatesNestedLoop(producer, consumer)) {
    VLOG(5) << "Fusing " << producer.name() << " into " << consumer.name()
            << " creates nested loop";
    return false;
  }

  // Do not fuse into reduce input fusions if the resulting kernel would
  // suffer from poor data locality (due to unfriendly input layouts).
  if (IsInputFusibleReduction(consumer) &&
      !LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    VLOG(5) << "Layout of " << producer.name()
            << " is not fusion-friendly for consumer reduction "
            << consumer.name();
    return false;
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX.  The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    VLOG(5) << "Not fusing constant " << producer.name() << " into "
            << consumer.name();
    return false;
  }

  return true;
}

bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  if (!IsLoopFusible(producer) ||
      !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    return false;
  }
  return true;
}

// Returns shared memory usage for a given instruction in bytes.
static int64 SharedMemoryUsage(const HloInstruction& instr) {
  // For now we are only fusing reductions.
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64 primitive_size =
        ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
      return 32 * primitive_size;
    } else {
      // __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
      return 2 * 32 * 33 * primitive_size;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64 sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += SharedMemoryUsage(*hlo);
    }
    return sum;
  }
  // Other fused expressions for now don't need the shared memory budget.
  return 0;
}
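
// As a worked example of the budget above (f32, 4 bytes per element): a
// fused row reduction accounts for 32 * 4 = 128 bytes of shared memory,
// while a fused column reduction accounts for 2 * 32 * 33 * 4 = 8448 bytes,
// so kSharedMemoryBudgetInBytes (40000) admits at most four f32 column
// reductions in a single fusion.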

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
// exactly what that limit is hazy, as it depends on (among other things) how
// much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how
// many arguments the CUDA kernel for a fusion node will have: It depends on
// buffer assignment, where we will decide which of the fusion's operands
// live in XLA's big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs
// at kMaxOperandsAndOutputsPerFusion.  This puts an upper bound on the
// number of parameters to the kernel, working around the correctness
// problem.
//
// This limit is also often good for performance.  In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If the fusion is a producer/consumer fusion and instr1 is the
// consumer and instr2 is the producer, set is_consumer_producer_fusion
// to true to enable more fusion.
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
                           const HloInstruction& instr2,
                           bool is_consumer_producer_fusion) {
  if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
      kSharedMemoryBudgetInBytes) {
    VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
            << " and " << instr2.ToString() << " would be over the budget of "
            << kSharedMemoryBudgetInBytes << "B";
    return true;
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs.  Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs.  In this case, we should count the tuple
  //    buffer only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`.  In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
  // big difference), we ignore this small inaccuracy in favor of simplicity.
  int64 num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                             ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between `a`
  // and `b`).
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      kMaxOperandsAndOutputsPerFusion) {
    return false;
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << " ) = " << instr1.operand_count()
            << " and ( " << instr2.ToString()
            << " ) = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << kMaxOperandsAndOutputsPerFusion;
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(
      instr1.operands().begin(), instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `a` and `b`, don't count it: We're fusing that
  // producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);
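
  // Note that shared operands are deduplicated here: if `instr1` and
  // `instr2` have an operand in common, the set contains it once, so this
  // precise count can be strictly smaller than the quick upper bound checked
  // above.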

  // If the fusion does not increase the number of inputs over what the
  // consumer (`instr1`) already had, the fused node is no bigger than
  // before, so accept it. Since this is a consumer/producer fusion, the
  // number of outputs is unchanged and does not need to be re-checked.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return false;
  }

  // Does the new fusion have more operands and outputs than the max?
  return operands.size() + num_output_buffers >
         kMaxOperandsAndOutputsPerFusion;
}
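
// Example: fusing a producer that reads elements multiple times (e.g. a
// reduce-window whose window size exceeds its stride) into a consumer that
// also reads its input elements multiple times would nest the two loops,
// multiplying the redundant reads; this predicate detects that situation.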
bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If the producer does not have an instruction that codegens a loop then
  // there is nothing to do.
  auto producer_has_loop_codegen = [&]() {
    if (producer.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(producer);
    }
    for (const auto& instr : producer.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*instr)) {
        return true;
      }
    }
    return false;
  };
  if (!producer_has_loop_codegen()) {
    return false;
  }

  // If consumer is a non-fusion instruction then we have to check if it
  // generates a loop.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If consumer is a fusion then we have to check if the output of producer
  // is used directly or indirectly as an input to an HLO instruction that
  // generates a loop, i.e. there is a path in the graph from an operand
  // corresponding to the producer to an HLO instruction generating a loop in
  // the consumer.
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const auto& user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.

  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible.
          instr.IsElementwise());
}

HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

}  // namespace gpu
}  // namespace xla