1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
42 
43 namespace xla {
44 namespace {
45 
46 // Combines the elements of to_combine into a single AllReduce op. All
47 // entries in to_combine must be AllReduce ops with exactly one operand
48 // and the same reduction operation.
CombineAllReduces(absl::Span<HloInstruction * const> to_combine)49 Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
50   if (to_combine.size() < 2) {
51     return Status::OK();
52   }
53   VLOG(1) << "Combined " << to_combine.size() << " CRS ops";
54 
55   HloComputation& computation = *to_combine.back()->parent();
56   HloComputation* reduction = to_combine[0]->to_apply();
57   const HloOpcode type = reduction->root_instruction()->opcode();
58 
59   // Create a single bigger AllReduce of the operands of the smaller
60   // AllReduces.
61   std::vector<HloInstruction*> operands;
62   std::vector<Shape> operand_shapes;
63   VLOG(1) << "Combining set";
64   for (HloInstruction* hlo : to_combine) {
65     VLOG(1) << "Set element: " << hlo->ToString();
66     TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
67     TF_RET_CHECK(hlo->operands().size() == 1);
68     TF_RET_CHECK(hlo->to_apply() == reduction ||
69                  (hlo->to_apply()->instruction_count() == 3 &&
70                   hlo->to_apply()->num_parameters() == 2 &&
71                   hlo->to_apply()->root_instruction()->opcode() == type));
72     TF_RET_CHECK(hlo->shape().IsArray());
73     for (HloInstruction* operand : hlo->operands()) {
74       operands.push_back(operand);
75       operand_shapes.push_back(operand->shape());
76     }
77   }
78 
79   HloInstruction* combined;
80   // AllReduce ops with more than one operand produce a tuple.
81   TF_RET_CHECK(operands.size() >= 2);
82   combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
83       ShapeUtil::MakeTupleShape(operand_shapes), operands, reduction,
84       to_combine.front()->replica_groups(),
85       /*constrain_layout=*/false, to_combine.front()->channel_id(),
86       Cast<HloAllReduceInstruction>(to_combine.front())
87           ->use_global_device_ids()));
88 
89   // We have to propagate the sharding manually because Domain instructions are
90   // not guaranteed to preserve it for side effecting instructions.
91   if (to_combine.front()->has_sharding()) {
92     combined->set_sharding(to_combine.front()->sharding());
93   }
94   VLOG(1) << "Replacing with : " << combined->ToString();
95 
96   // Replace all the smaller AllReduces with elements of the tuple output
97   // of the single bigger AllReduce.
98   for (int64 i = 0; i < to_combine.size(); ++i) {
99     auto replace_with = HloInstruction::CreateGetTupleElement(
100         to_combine[i]->shape(), combined, i);
101     TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
102         to_combine[i], std::move(replace_with)));
103   }
104   return Status::OK();
105 }
106 
107 struct GroupKey {
GroupKeyxla::__anonc0249c850111::GroupKey108   GroupKey(const HloInstruction* hlo, const HloDomainMap& domain_map)
109       : opcode(hlo->to_apply()->root_instruction()->opcode()),
110         accum_type(hlo->to_apply()->root_instruction()->shape().element_type()),
111         domain_id(domain_map.GetDomainMetadataId(hlo)),
112         is_cross_shard(hlo->channel_id().has_value()),
113         use_global_device_ids(
114             Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids()),
115         replica_groups(hlo->replica_groups()) {}
116 
operator <xla::__anonc0249c850111::GroupKey117   bool operator<(const GroupKey& other) const {
118     if (opcode != other.opcode) {
119       return opcode < other.opcode;
120     }
121     if (accum_type != other.accum_type) {
122       return accum_type < other.accum_type;
123     }
124     if (domain_id != other.domain_id) {
125       return domain_id < other.domain_id;
126     }
127     if (is_cross_shard != other.is_cross_shard) {
128       return is_cross_shard < other.is_cross_shard;
129     }
130     if (use_global_device_ids != other.use_global_device_ids) {
131       return use_global_device_ids < other.use_global_device_ids;
132     }
133     if (replica_groups.size() != other.replica_groups.size()) {
134       return replica_groups.size() < other.replica_groups.size();
135     }
136     for (int64 i = 0; i < replica_groups.size(); ++i) {
137       const auto& rg = replica_groups[i];
138       const auto& org = other.replica_groups[i];
139       if (rg.replica_ids_size() != org.replica_ids_size()) {
140         return rg.replica_ids_size() < org.replica_ids_size();
141       }
142       for (int64 j = 0; j < rg.replica_ids_size(); ++j) {
143         if (rg.replica_ids(j) != org.replica_ids(j)) {
144           return rg.replica_ids(j) < org.replica_ids(j);
145         }
146       }
147     }
148     return false;
149   }
150 
151   HloOpcode opcode;
152   PrimitiveType accum_type;
153   int64 domain_id;
154   bool is_cross_shard;
155   bool use_global_device_ids;
156   std::vector<ReplicaGroup> replica_groups;
157 };
158 
159 // Group AllReduce instructions by the reduction types, e.g., add, min,
160 // max, replica groups and domain. For cross-module all reduce instructions
161 // we group them by the set of domains they are reducing across.
162 //
163 // Note that the shape of the reduction computation is not included in the
164 // reduction types, e.g.: "f32[] add" and "bf16[] add" will be the same type. We
165 // need to disallow combining CRS instructions with different domain metadata as
166 // well as that could end up short-cutting two or more different domains.
167 //
168 // In each group, the instructions should be in post order. We will then iterate
169 // each group and try to combine them, so to prevent non-determinism, we use
170 // std::map here.
171 //
172 // The return value is a list of groups where every group contains a list of
173 // all-reduce instruction sets in topological order and with a deterministic
174 // order within the set. Additionally due to the above constraints every all
175 // reduce set within a group will contain the same number of elements
176 // and every instruction within an all reduce set will have the same
177 // all-reduce-id (if specified) and thus shape (all reduce sets without an
178 // all-reduce-id will have a single instruction).
179 using InstructionGroups =
180     std::vector<std::vector<std::vector<HloInstruction*>>>;
CreateComputationGroups(HloComputation * computation)181 StatusOr<InstructionGroups> CreateComputationGroups(
182     HloComputation* computation) {
183   TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
184 
185   // Group instructions by opcode, domain id and replica group.
186   std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
187   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
188     if (instruction->opcode() != HloOpcode::kAllReduce) {
189       continue;
190     }
191     if (instruction->to_apply()->instruction_count() != 3 ||
192         instruction->to_apply()->num_parameters() != 2) {
193       VLOG(1) << "Skipping due to non-trivial reduction function.";
194       continue;
195     }
196     opcode_groups[GroupKey(instruction, *domain_map)].push_back(instruction);
197   }
198 
199   // Generate a unique all-reduce-id for instructions without one by negating
200   // the unique id of the hlo. This way we can treat cross module and normal CRS
201   // instructions uniformly.
202   auto channel_id = [](const HloInstruction* all_reduce) {
203     return all_reduce->IsCrossModuleAllReduce()
204                ? all_reduce->channel_id().value()
205                : -1 * all_reduce->unique_id();
206   };
207 
208   // Group instructions by all-reduce id with instructions for an all-reduce id
209   // is listed along their group id and the (group id, instruction) pairs are
210   // sorted by group id in the vector.
211   std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
212       all_reduce_sets;
213   int64 group_id = 0;
214   for (auto& domain_groups : opcode_groups) {
215     for (HloInstruction* hlo : domain_groups.second) {
216       all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
217     }
218     ++group_id;
219   }
220 
221   // Group instructions by participating group ids. Instructions within a group
222   // are sorted by topological order and instructions within an all reduce group
223   // is still sorted by group id.
224   std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
225       all_reduce_group_map;
226   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
227     if (instruction->opcode() != HloOpcode::kAllReduce) {
228       continue;
229     }
230     if (instruction->to_apply()->instruction_count() != 3 ||
231         instruction->to_apply()->num_parameters() != 2) {
232       VLOG(1) << "Skipping due to non-trivial reduction function.";
233       continue;
234     }
235 
236     int64 arid = channel_id(instruction);
237     if (all_reduce_sets.count(arid) == 0) {
238       // Already processed.
239       continue;
240     }
241 
242     std::vector<int64> group_ids;
243     std::vector<HloInstruction*> instructions;
244     for (const auto& hlo : all_reduce_sets[arid]) {
245       group_ids.push_back(hlo.first);
246       instructions.push_back(hlo.second);
247     }
248     all_reduce_group_map[group_ids].push_back(std::move(instructions));
249     all_reduce_sets.erase(arid);
250   }
251   CHECK(all_reduce_sets.empty());
252 
253   InstructionGroups groups;
254   for (const auto& all_reduce_group : all_reduce_group_map) {
255     groups.push_back(all_reduce_group.second);
256   }
257   return std::move(groups);
258 }
259 
260 }  // namespace
261 
AllReduceCombiner(int64 combine_threshold_in_bytes,int64 combine_threshold_count)262 AllReduceCombiner::AllReduceCombiner(int64 combine_threshold_in_bytes,
263                                      int64 combine_threshold_count)
264     : combine_threshold_in_bytes_(combine_threshold_in_bytes),
265       combine_threshold_count_(combine_threshold_count) {}
266 
Run(HloModule * module)267 StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
268   VLOG(1) << "Running AllReduceCombiner with threshold of "
269           << combine_threshold_in_bytes_ << " bytes";
270 
271   if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
272     VLOG(1) << "Skip AllReduceCombiner because the threshold is zero";
273     return false;
274   }
275 
276   if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
277     VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
278                "with constrained layouts";
279     return false;
280   }
281 
282   bool changed = false;
283   for (HloComputation* computation : module->MakeNonfusionComputations()) {
284     TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
285     for (auto group : groups) {
286       // Recompute reachability after every combine group because we can't
287       // maintain a cross group topolgical order to be able to rely on the
288       // transitive dependencies to detect cycles.
289       auto reachability = HloReachabilityMap::Build(computation);
290 
291       // Create a map to be able to find an instruction group based on the first
292       // instruction in the group. It will be used during the post order
293       // iteration to be able to process full groups at a time. Doing it only
294       // for one instruction in every group will be sufficient because all
295       // instruction have to schedule at the same time due to cross core
296       // dependencies.
297       absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
298           group_map;
299       for (auto& instruction : group) {
300         group_map[instruction.front()] = &instruction;
301       }
302 
303       // Collect sets of AllReduce instructions to combine.
304       std::vector<std::vector<std::vector<HloInstruction*>>> combine_sets(1);
305       int64 current_size_in_bytes = 0;
306       int64 current_operand_count = 0;
307 
308       // Iterate all instructions in post order and skip the ones not in the
309       // current group. We have to create a new post order iteration for every
310       // group because merging instructions in the previous group can made the
311       // original post order no longer hold.
312       // This will make it likely that we won't increase memory pressure much
313       // above combine_threshold_in_bytes, since two AllReduces that are
314       // near in post order are most likely, but not for sure, also near in
315       // scheduled order.
316       //
317       // TODO(b/70235266): This should usually be fine, but it's probably
318       // possible to construct some case where the memory usage increases beyond
319       // the threshold due to reordering of the instructions in scheduling. If
320       // this ever comes up as a real problem, it would be nice to implement
321       // safeguards so that that cannot possibly happen.
322       for (const HloInstruction* inst :
323            computation->MakeInstructionPostOrder()) {
324         auto it = group_map.find(inst);
325         if (it == group_map.end()) {
326           // Instruction belongs to a different group.
327           continue;
328         }
329         const auto& instructions = *it->second;
330 
331         VLOG(1) << "Considering HLO " << instructions.front()->ToString()
332                 << " with current set size of " << current_size_in_bytes
333                 << " and current operand count of " << current_operand_count;
334 
335         // We do not handle AllReduce ops that do not have exactly 1
336         // operand since that is simpler and this pass is the only way to
337         // generate such ops and it should rarely be important to consider the
338         // same ops again.
339         if (instructions.front()->operands().size() != 1) {
340           VLOG(1) << "Skipping due to "
341                   << instructions.front()->operands().size() << " operands";
342           continue;
343         }
344 
345         int64 size_in_bytes;
346         TF_RET_CHECK(instructions.front()->shape().IsArray());
347         size_in_bytes = ShapeUtil::ByteSizeOf(instructions.front()->shape());
348 
349         if (size_in_bytes > combine_threshold_in_bytes_) {
350           VLOG(1) << "Skipping due to size " << size_in_bytes
351                   << " above threshold";
352           // If the instruction is greather than the threshold, then we can
353           // never combine it with anything.
354           continue;
355         }
356 
357         // If the current set is dependent on the instruction, then create a new
358         // one to avoid the dependency. We move on from the current set instead
359         // of ignoring the instruction since otherwise a single AllReduce
360         // instruction that all the other ones depend on (such as one on the
361         // forward pass of a model) could disable this optimization entirely.
362         TF_RET_CHECK(!combine_sets.empty());
363         for (const auto& previous : combine_sets.back()) {
364           // The reachability information does not reflect the planned
365           // combination from combine_sets. We cannot just bring it up to date
366           // cheaply since HloReachabilityMap does not track reachability
367           // updates transitively and doing it directly is expensive. However,
368           // leaving it stale has no effect on the reachability queries that we
369           // are doing here because we are considering the ops in a topological
370           // order, so we can just leave it stale.
371           //
372           // Proof: Suppose A is the instruction we are looking to combine and B
373           // is an element of the current combine set that we are looking to
374           // combine A into.
375           //
376           // First of all, we check that all elements in each set do not depend
377           // on each other, so combining the *current* combine set cannot create
378           // new dependencies between A and B. It remains to prove that
379           // combining the prior combine sets also cannot create a dependency
380           // between A and B.
381           //
382           // Assume to get a contradiction that there are two AllReduce
383           // ops C and D in combine_sets that will be combined and that A and B
384           // are not connected now but that they will be after combining C and
385           // D. Then there exist paths in the dependency graph such that one of
386           // these cases is true:
387           //
388           //   A -> ... -> C and D -> ... -> B
389           //   A -> ... -> D and C -> ... -> B
390           //   B -> ... -> C and D -> ... -> A
391           //   B -> ... -> D and C -> ... -> A
392           //
393           // None of these cases are possible because we are visiting the nodes
394           // in a topological order, so C and D cannot be in-between A and B.
395           // That is a contradiction, so combining the prior combine sets also
396           // cannot create a dependency between A and B.
397           bool new_set = false;
398           for (int64 i = 0; i < instructions.size(); ++i) {
399             if (reachability->IsReachable(previous[i], instructions[i])) {
400               VLOG(1) << "Starting new set due to dependency between "
401                       << previous[i]->ToString() << " AND "
402                       << instructions[i]->ToString();
403               new_set = true;
404               break;
405             }
406           }
407           if (new_set) {
408             combine_sets.emplace_back();
409             current_size_in_bytes = 0;
410             current_operand_count = 0;
411             break;
412           }
413         }
414 
415         if (current_size_in_bytes + size_in_bytes >
416                 combine_threshold_in_bytes_ ||
417             current_operand_count + 1 > combine_threshold_count_) {
418           VLOG(1) << "The instruction cannot be entered into the set due "
419                      "to the combined size being too large.";
420           // In this case we cannot include the instruction into the current set
421           // since then it would grow beyond the threshold. The set of
422           // instructions to carry forward will either be the current set or the
423           // instruction by itself, whichever is smaller, since that maximizes
424           // the chance of being able to combine with the next instruction.
425           if (size_in_bytes > current_size_in_bytes) {
426             VLOG(1) << "Skipping as the instruction is larger than the set.";
427             continue;  // keep the current set
428           }
429           VLOG(1)
430               << "Resetting the set as the set is larger than the instruction.";
431           combine_sets.emplace_back();
432           current_size_in_bytes = 0;
433           current_operand_count = 0;
434         }
435 
436         VLOG(1) << "Adding instruction to set.";
437         combine_sets.back().push_back(instructions);
438         current_size_in_bytes += size_in_bytes;
439         current_operand_count += 1;
440         TF_RET_CHECK(current_size_in_bytes <= combine_threshold_in_bytes_);
441         TF_RET_CHECK(current_operand_count <= combine_threshold_count_);
442       }
443       VLOG(1) << "Done constructing sets. Final set size is "
444               << current_size_in_bytes << " bytes and " << current_operand_count
445               << " operands";
446 
447       // Combine the collected sets of AllReduce instructions.
448       for (const auto& combine_set : combine_sets) {
449         if (combine_set.size() >= 2) {
450           changed = true;
451           for (int64 i = 0; i < combine_set.front().size(); ++i) {
452             std::vector<HloInstruction*> to_combine;
453             to_combine.reserve(combine_set.size());
454             for (const auto& c : combine_set) {
455               to_combine.push_back(c[i]);
456             }
457             TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));
458           }
459         }
460       }
461     }
462   }
463 
464   return changed;
465 }
466 
467 }  // namespace xla
468