/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"

namespace xla {
namespace gpu {

namespace {

// Returns the instructions that produce the outputs of the fused computation
// of `instr`: the tuple operands for a multi-output fusion, or the root
// itself otherwise.
absl::InlinedVector<HloInstruction*, 2> GetOutputsOfFusion(
    const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  } else {
    return root->operands();
  }
}

// Returns the number of outputs of the fused computation.
size_t GetOutputSizeOfFusion(const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  const HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return 1;
  } else {
    return ShapeUtil::TupleElementCount(root->shape());
  }
}

// Returns the unique element type of the outputs of the fused computation.
// CHECK-fails if the outputs have more than one element type; callers must
// have filtered such fusions out (see IsFusionSupported).
PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) {
  auto outputs = GetOutputsOfFusion(instr);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but see "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. It is required that each of
  // `fused_fusion_instrs` is a kLoop fusion. Also, we require their numbers of
  // outputs to be the same, so that each output will be fused/concatenated
  // with the same number of outputs from other fused fusion instrs. Then, all
  // the fused outputs still have the same shapes for kernel generation.
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
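  //
  // For example (shapes chosen purely for illustration): fusing two kLoop
  // fusions whose roots are f32[32,32] and f32[64,32] produces a computation
  // that reshapes the cloned roots to f32[1024] and f32[2048], concatenates
  // them into f32[3072], and slices the result back apart, returning
  // tuple(f32[1024], f32[2048]); bitcasts outside the fusion then restore the
  // original shapes.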
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
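  //
  // Expected usage, equivalent to the loop in Run() below:
  //   FusionCandidates candidates(consumer);
  //   for (auto span = candidates.GetNextSpanOfFusions(); !span.empty();
  //        span = candidates.GetNextSpanOfFusions()) {
  //     // Fuse the instructions in `span` horizontally.
  //   }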
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusion_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusion_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
};  // HorizontalLoopFusionImpl
118 
IsFusionSupported(const HloInstruction & instr)119 bool IsFusionSupported(const HloInstruction& instr) {
120   // Support only kLoop fusion now.
121   if (!instr.IsLoopFusion()) {
122     return false;
123   }
124 
125   // Cannot support fusion who has multiple output types, because the
126   // concatenate (inserted for horizontal fusion) requires the same type
127   // for all of its operands.
128   auto outputs = GetOutputsOfFusion(instr);
129   CHECK(!outputs.empty());
130   const HloInstruction* first_output = outputs[0];
131   for (size_t i = 1; i < outputs.size(); ++i) {
132     if (first_output->shape().element_type() !=
133         outputs[i]->shape().element_type()) {
134       return false;
135     }
136   }
137 
138   return true;
139 }

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing kernel
// launch overhead, we want to exclude instructions with insignificant launch
// overhead. In other words, we exclude instructions whose computation latency
// exceeds their launch latency. We estimate the computation latency of a given
// instruction by its shapes and the instruction count of its fused
// computation. We roughly observe that if a fusion instruction has shapes
// smaller than `kShapeThreshold` and fewer instructions than
// `kInstrCountThreshold`, it is launch-latency-bound and a profitable
// candidate for horizontal fusion.
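//
// For example, a candidate whose root shape is f32[128,1024] (131072 elements,
// under kShapeThreshold = 128 * 2048 = 262144) with 10 fused instructions
// passes both checks, while one producing f32[1024,1024] is rejected
// regardless of its instruction count.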
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  constexpr int64 kShapeThreshold = 128 * 2048;
  constexpr int64 kInstrCountThreshold = 30;
  auto root = instr.fused_expression_root();

  // Too-large shapes are not easily profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    auto shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    auto shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Having too many instructions is not easily profitable.
  if (instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  // We can emit DUS (dynamic-update-slice) in-place, but horizontally fusing
  // it makes the emitter no longer recognize that it can be done in-place,
  // which creates much slower code. This restriction could be lifted if buffer
  // assignment recognized that the DUS can be done in-place even inside a
  // horizontal fusion.
  if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
    return false;
  }

  return true;
}

// Returns whether `fusion_instr` has only row-major layouts.
// Horizontal fusion excludes computations with non-row-major layouts, because
// fusing computations with different layouts can result in uncoalesced memory
// accesses and incur significant performance overhead.
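//
// For example (shapes hypothetical), a fused instruction of shape
// f32[64,32]{1,0} is row-major and acceptable, whereas one of shape
// f32[64,32]{0,1} (column-major) causes the whole fusion to be rejected.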
bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
  CHECK(fusion_instr.opcode() == HloOpcode::kFusion);
  auto instrs = fusion_instr.fused_instructions_computation()->instructions();
  for (auto instr : instrs) {
    if (instr->shape().layout().format() != DENSE) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
      return false;
    }
  }
  return true;
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find all fusion instructions. We will filter out
  // unsupported/non-profitable cases below.
  absl::flat_hash_set<HloInstruction*> fusion_instrs;
  for (auto opnd : consumer->operands()) {
    auto predecessor = opnd->LatestNonGteAncestor();
    if (predecessor->opcode() == HloOpcode::kFusion) {
      fusion_instrs.insert(predecessor);
    }
  }

  for (auto instr : fusion_instrs) {
    if (!IsFusionSupported(*instr)) {
      VLOG(2) << "Reject unsupported fusion instr " << instr->ToString();
      continue;
    } else if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject maybe illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject likely-unprofitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Found a fusion candidate " << instr->ToString();
      fusion_instrs_.push_back(instr);
    }
  }

  // Sort `fusion_instrs_` according to output types, the number of outputs,
  // and instruction counts, because we only fuse instructions with the same
  // number/type of outputs and whose computations have the same instruction
  // count.
  std::sort(
      fusion_instrs_.begin(), fusion_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusion(*a) !=
            GetUniqueOutputTypeOfFusion(*b)) {
          return GetUniqueOutputTypeOfFusion(*a) <
                 GetUniqueOutputTypeOfFusion(*b);
        } else if (GetOutputSizeOfFusion(*a) != GetOutputSizeOfFusion(*b)) {
          return GetOutputSizeOfFusion(*a) < GetOutputSizeOfFusion(*b);
        } else {
          return a->fused_instruction_count() < b->fused_instruction_count();
        }
      });
}

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusion_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may not be easily profitable and
  // may increase compile time due to large kernels. Set a limit to it.
  constexpr int64 kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64 kMaxCudaParamSize = 4000;
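  // Every parameter and output of the fused kernel is passed as an 8-byte
  // pointer, so the batch stops growing once the accumulated I/O count would
  // reach roughly kMaxCudaParamSize / 8 = 500 buffers.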
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusion_instrs_.at(right)->fused_parameters().size() +
                     GetOutputSizeOfFusion(*fusion_instrs_.at(right));

    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

  size_t left = pos_;
  size_t right = pos_ + 1;
  size_t first_output_size = GetOutputSizeOfFusion(*fusion_instrs_[left]);
  PrimitiveType first_output_type =
      GetUniqueOutputTypeOfFusion(*fusion_instrs_[left]);
  for (; right < fusion_instrs_.size(); ++right) {
    PrimitiveType cur_output_type =
        GetUniqueOutputTypeOfFusion(*fusion_instrs_[right]);
    if (first_output_type != cur_output_type) {
      // Cannot fuse computations that have multiple output types.
      break;
    } else if (first_output_size !=
               GetOutputSizeOfFusion(*fusion_instrs_[right])) {
      // Cannot fuse computations that have different numbers of outputs.
      break;
    } else if (fusion_instrs_[left]->fused_instruction_count() !=
               fusion_instrs_[right]->fused_instruction_count()) {
      // Do not fuse computations of different instruction counts, as it may
      // introduce control divergence. This is a very simple heuristic to avoid
      // fusing computations with too much discrepancy, and we may improve it
      // when the need arises.
      break;
    } else if (reach_max_fusion_batch_size(left, right)) {
      // Hit the max fusion batch size.
      break;
    }
  }

  pos_ = right;
  return absl::MakeSpan(fusion_instrs_).subspan(left, right - left);
}

Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b("horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      auto bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Creates parameters in the form of param_i_j, where i indexes the
      // fused fusion instruction and j indexes its operand.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as the existence of a root instruction is required by
  // HloComputation. The real root instruction will replace it below.
  auto dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  auto* comp = uniq_computation->get();

  // Prepare clone_map, which maps an old operand to its new counterpart.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      auto old_param = old_params[j];
      auto new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

  // Clone every fused computation.
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto def_to_use_order = fused_fusion_instrs[i]
                                ->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
    for (auto old_instr : def_to_use_order) {
      if (old_instr->opcode() == HloOpcode::kParameter) {
        // Parameters have already been created above.
        continue;
      }
      std::vector<HloInstruction*> new_opnds;
      for (auto old_opnd : old_instr->operands()) {
        CHECK(clone_map.find(old_opnd) != clone_map.end());
        new_opnds.push_back(clone_map[old_opnd]);
      }
      auto new_instr = comp->AddInstruction(
          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
      clone_map.insert({old_instr, new_instr});
    }
  }

  std::vector<HloInstruction*> concated_outputs;
  // Since we require each fusion to have the same number of outputs, we can
  // simply use the first fusion as the representative for the output count.
  size_t fused_instr_output_size =
      GetOutputSizeOfFusion(*fused_fusion_instrs[0]);
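  // For each output index i, reshape the i-th output of every fused fusion to
  // 1-D and concatenate the reshapes along dimension 0, e.g. (hypothetical
  // shapes) f32[32,32] and f32[64,32] become f32[1024] and f32[2048], which
  // are concatenated into f32[3072].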
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      auto old_output = GetOutputsOfFusion(*fused_fusion_instrs[j])[i];
      auto new_output = clone_map[old_output];
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(auto concated_output, MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }

  // Make slices of outputs.
  std::vector<HloInstruction*> output_slices(concated_outputs.size() *
                                             fused_fusion_instrs.size());
  for (size_t i = 0; i < concated_outputs.size(); ++i) {
    auto concated_output = concated_outputs[i];
    int64 slice_start = 0;
    // Create a slice per fused computation.
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      auto old_output = GetOutputsOfFusion(*fused_fusion_instrs[j])[i];
      auto shape = old_output->shape();
      int64 slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
      TF_ASSIGN_OR_RETURN(
          output_slices[concated_outputs.size() * j + i],
          MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
                       /*strides=*/{1}));
      slice_start = slice_limit;
    }
  }

  // Make a tuple of output_slices.
  auto tuple = comp->AddInstruction(HloInstruction::CreateTuple(output_slices));
  comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
  TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));

  return Status::OK();
}

Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  auto fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
  auto hori_fusion_instr =
      computation_->AddInstruction(HloInstruction::CreateFusion(
          fused_comp->root_instruction()->shape(),
          HloInstruction::FusionKind::kInput, bound_operands, fused_comp));
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace corresponding users. Note that we do not
  // insert the bitcasts in the fused computation, as doing so does not fit
  // into the slice input fusion pattern. However, inserting bitcasts outside
  // the fused computation incurs no performance cost.
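  //
  // Each bitcast restores the original (possibly multi-dimensional) shape of
  // an output from its 1-D slice, e.g. bitcasting a f32[1024] slice back to
  // f32[32,32] (shapes hypothetical).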
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    auto fused_instr = fused_fusion_instrs[i];
    auto num_outputs = GetOutputSizeOfFusion(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      auto output = GetOutputsOfFusion(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(auto gep, MakeGetTupleElementHlo(hori_fusion_instr,
                                                           total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    auto bitcast_or_tuple = (bitcasts.size() == 1)
                                ? bitcasts.at(0)
                                : computation_->AddInstruction(
                                      HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return Status::OK();
}

StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Using def-to-use order is sound since we do not modify users.
  std::vector<HloInstruction*> def_to_use_order =
      computation_->MakeInstructionPostOrder();
  for (size_t i = 0; i < def_to_use_order.size(); ++i) {
    auto consumer = def_to_use_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusions = fusion_candidates.GetNextSpanOfFusions();
      if (fusions.empty()) {
        break;
      } else if (fusions.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      TF_RETURN_IF_ERROR(Fuse(fusions));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";
  for (auto* comp : module->MakeNonfusionComputations()) {
    // Accumulate rather than overwrite, so that a change in any computation is
    // reported even if a later computation is left unchanged.
    TF_ASSIGN_OR_RETURN(bool changed_for_comp, RunOnComputation(comp));
    changed |= changed_for_comp;
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla