1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/call_graph.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_query.h"
30 #include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 // In SPMD mode, if there's a cross-replica all-reduce that produces the same
40 // value for all partitions, replaces it with a global all-reduce and then
41 // divide by the number of partitions. Depending on the topology and the
42 // implementation of the all-reduce for the backend, this may give a better
43 // performance.
ReplaceReplicatedAllReduce(HloModule * module,int64 replica_count,int64 partition_count)44 StatusOr<bool> ReplaceReplicatedAllReduce(HloModule* module,
45                                           int64 replica_count,
46                                           int64 partition_count) {
47   TF_ASSIGN_OR_RETURN(
48       auto replication_analysis,
49       HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
50 
51   bool changed = false;
52   int64 next_channel = hlo_query::NextChannelId(*module);
53   for (auto computation : module->computations()) {
54     for (auto instruction : computation->instructions()) {
55       if (auto ar = DynCast<HloAllReduceInstruction>(instruction)) {
56         const Shape& shape = ar->shape();
57         if (ar->channel_id()) {
58           continue;
59         }
60         if (ar->replica_groups().size() > 1) {
61           continue;
62         }
63         if (shape.IsTuple() || shape.element_type() != F32) {
64           continue;
65         }
66         // We would need a cost model for the target, but in general we want to
67         // rewrite only if the replica count in the original op was large.
68         if (replica_count < 8 * partition_count) {
69           continue;
70         }
71         if (replication_analysis->HloInstructionIsReplicatedAt(ar, {})) {
72           VLOG(2) << "Replaced replicated all-reduce:" << ar->ToString();
73           ar->set_channel_id(next_channel++);
74           auto divisor =
75               computation->AddInstruction(HloInstruction::CreateConstant(
76                   LiteralUtil::CreateR0<float>(partition_count)));
77           auto bcast = computation->AddInstruction(
78               HloInstruction::CreateBroadcast(shape, divisor, {}));
79           auto div = computation->AddInstruction(HloInstruction::CreateBinary(
80               ar->shape(), HloOpcode::kDivide, ar, bcast));
81           TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div));
82           changed = true;
83         }
84       }
85     }
86   }
87   return changed;
88 }
89 
90 // Returns true if the given instruction (must be a cross-partition all-reduce)
91 // has a ReplicaGroup config that can be combined with cross-replica all-reduce.
92 // We currently restrict to those groups where all partitions in each replica
93 // belong to the same group.
HasCombinableReplicaGroup(HloInstruction * hlo,int64 num_replicas,int64 num_partitions)94 bool HasCombinableReplicaGroup(HloInstruction* hlo, int64 num_replicas,
95                                int64 num_partitions) {
96   auto all_reduce = Cast<HloAllReduceInstruction>(hlo);
97   auto replica_groups = all_reduce->replica_groups();
98   CHECK(all_reduce->IsCrossModuleAllReduce());
99 
100   if (all_reduce->use_global_device_ids()) {
101     if (replica_groups.size() != num_replicas) {
102       return false;
103     }
104     for (const auto& group : replica_groups) {
105       if (group.replica_ids_size() != num_partitions) {
106         return false;
107       }
108       std::unordered_set<int64> partition_ids;
109       int64 replica_id = group.replica_ids(0) / num_partitions;
110       for (int64 i = 0; i < num_partitions; ++i) {
111         if (group.replica_ids(i) / num_partitions != replica_id) {
112           return false;
113         }
114         partition_ids.insert(group.replica_ids(i) % num_partitions);
115       }
116       if (partition_ids.size() != num_partitions) {
117         return false;
118       }
119     }
120     return true;
121   }
122 
123   return replica_groups.size() == num_replicas;
124 }
125 
126 }  // namespace
127 
128 namespace m = match;
129 
130 // Checks if the argument instruction is an AllReduce, followed by a certain
131 // sequence of instructions and then a CRS. It must be possible to move
132 // the AR past each instruction in the sequence.
MatchesArCrsPattern(HloInstruction * instruction)133 absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
134     HloInstruction* instruction) {
135   auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
136     if (instruction->user_count() != 1) {
137       return false;
138     }
139     switch (instruction->opcode()) {
140       case HloOpcode::kBitcast:
141       case HloOpcode::kTranspose:
142       case HloOpcode::kReshape:
143         return true;
144       case HloOpcode::kConvert:
145         // Can be moved across if both input and output is either float or
146         // integer (e.g. S32<->U32 or F32<->BF16)
147         return ShapeUtil::ElementIsFloating(instruction->shape()) ==
148                ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
149       case HloOpcode::kAdd:
150       case HloOpcode::kSubtract:
151       case HloOpcode::kMultiply:
152         // Only supported for floating point operands.
153         return ShapeUtil::ElementIsFloating(instruction->shape());
154       default:
155         return false;
156     }
157   };
158 
159   auto computation_is_addition = [](HloComputation* c) {
160     return c->instruction_count() == 3 &&
161            Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
162   };
163 
164   // We only support combining cross-partition all-reduce where each replica
165   // belongs to its own group, since the later cross-replica all-reduce combines
166   // along the replica dimension.
167   if (instruction->IsCrossModuleAllReduce() &&
168       HasCombinableReplicaGroup(instruction, num_replicas_,
169                                 num_spatial_partitions_) &&
170       computation_is_addition(instruction->called_computations()[0]) &&
171       instruction->user_count() == 1) {
172     auto next = instruction->users()[0];
173     int64 distance = 1;
174     while (!next->IsCrossReplicaAllReduce()) {
175       if (can_ar_move_past_instruction(next)) {
176         next = next->users()[0];
177       } else {
178         return absl::nullopt;
179       }
180       ++distance;
181     }
182     if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
183         computation_is_addition(next->called_computations()[0])) {
184       ArCrsPair pair(instruction, next, distance);
185       VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString();
186       return pair;
187     }
188   }
189   return absl::nullopt;
190 }
191 
WhileFromBodyParameter(HloInstruction * instruction)192 absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
193     HloInstruction* instruction) {
194   CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
195   HloComputation* computation = instruction->parent();
196   auto caller_instructions = call_graph_->GetComputationCallers(computation);
197   if (caller_instructions.size() == 1) {
198     auto caller_instruction = caller_instructions[0];
199     if (caller_instruction->opcode() == HloOpcode::kWhile) {
200       return caller_instruction;
201     }
202   }
203   return absl::nullopt;
204 }
205 
ConditionalFromBodyParameter(HloInstruction * instruction)206 absl::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
207     HloInstruction* instruction) {
208   CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
209   HloComputation* computation = instruction->parent();
210   auto caller_instructions = call_graph_->GetComputationCallers(computation);
211   if (caller_instructions.size() == 1) {
212     auto caller_instruction = caller_instructions[0];
213     if (caller_instruction->opcode() == HloOpcode::kConditional) {
214       return caller_instruction;
215     }
216   }
217   return absl::nullopt;
218 }
219 
GetAllTuples(HloInstruction * instruction,absl::flat_hash_set<HloInstruction * > * visited)220 absl::optional<std::vector<HloInstruction*>> ArCrsCombiner::GetAllTuples(
221     HloInstruction* instruction,
222     absl::flat_hash_set<HloInstruction*>* visited) {
223   if (visited->find(instruction) != visited->end()) {
224     return std::vector<HloInstruction*>();
225   }
226   visited->insert(instruction);
227 
228   switch (instruction->opcode()) {
229     case HloOpcode::kTuple: {
230       return std::vector<HloInstruction*>({instruction});
231     }
232     case HloOpcode::kDomain: {
233       return GetAllTuples(instruction->operands()[0], visited);
234     }
235     case HloOpcode::kParameter: {
236       auto maybe_while = WhileFromBodyParameter(instruction);
237       if (maybe_while) {
238         auto while_instr = *maybe_while;
239         auto init_tuples = GetAllTuples(while_instr->while_init(), visited);
240         auto body_tuples = GetAllTuples(
241             while_instr->while_body()->root_instruction(), visited);
242         if (!init_tuples || !body_tuples) {
243           return absl::nullopt;
244         }
245         auto result = *init_tuples;
246         result.insert(result.end(), body_tuples->begin(), body_tuples->end());
247         return result;
248       }
249       auto maybe_conditional = ConditionalFromBodyParameter(instruction);
250       if (maybe_conditional) {
251         auto cond_instr = *maybe_conditional;
252         std::vector<HloInstruction*> tuples;
253         for (int64 i = 0; i < cond_instr->branch_computations().size(); ++i) {
254           if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
255               instruction) {
256             // If the same computation is used for more than one branch of the
257             // conditional, we collect the arguments that flow to the
258             // computation from all branches.
259             auto branch_tuples =
260                 GetAllTuples(cond_instr->mutable_operand(i + 1), visited);
261             if (!branch_tuples) {
262               return absl::nullopt;
263             }
264             tuples.insert(tuples.end(), branch_tuples->begin(),
265                           branch_tuples->end());
266           }
267         }
268         return tuples;
269       }
270       return absl::nullopt;
271     }
272     case HloOpcode::kGetTupleElement: {
273       std::vector<HloInstruction*> result_tuples;
274       auto tuples = GetAllTuples(instruction->operands()[0], visited);
275       if (!tuples) {
276         return absl::nullopt;
277       }
278       for (auto tuple : *tuples) {
279         auto tmp_tuples = GetAllTuples(
280             tuple->mutable_operand(instruction->tuple_index()), visited);
281         if (!tmp_tuples) {
282           return absl::nullopt;
283         }
284         result_tuples.insert(result_tuples.end(), tmp_tuples->begin(),
285                              tmp_tuples->end());
286       }
287       return result_tuples;
288     }
289     case HloOpcode::kConditional: {
290       std::vector<HloInstruction*> result_tuples;
291       for (HloComputation* body : instruction->branch_computations()) {
292         if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
293           return absl::nullopt;
294         }
295         result_tuples.push_back(body->root_instruction());
296       }
297       return result_tuples;
298     }
299     case HloOpcode::kWhile: {
300       auto init_tuples = GetAllTuples(instruction->while_init(), visited);
301       auto body_tuples =
302           GetAllTuples(instruction->while_body()->root_instruction(), visited);
303       if (!init_tuples || !body_tuples) {
304         return absl::nullopt;
305       }
306       auto result = *init_tuples;
307       result.insert(result.end(), body_tuples->begin(), body_tuples->end());
308       return result;
309     }
310     default:
311       return absl::nullopt;
312   }
313 }
314 
TupleElementsComputeSameValue(HloInstruction * tuple_shaped_instruction,int64 i1,int64 i2,absl::flat_hash_map<int64,int64> * visited_pairs)315 bool ArCrsCombiner::TupleElementsComputeSameValue(
316     HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
317     absl::flat_hash_map<int64, int64>* visited_pairs) {
318   absl::flat_hash_set<HloInstruction*> visited;
319   auto tuples = GetAllTuples(tuple_shaped_instruction, &visited);
320   if (!tuples) {
321     return false;
322   }
323   for (auto tuple : *tuples) {
324     CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
325     if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
326                                       tuple->mutable_operand(i2),
327                                       visited_pairs)) {
328       return false;
329     }
330   }
331   return true;
332 }
333 
334 /* static */
TestInstructionsComputeSameValue(HloInstruction * i1,HloInstruction * i2)335 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
336                                                      HloInstruction* i2) {
337   ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
338                          /*spmd_partition=*/false);
339   auto module = i1->parent()->parent();
340   CHECK_EQ(module, i2->parent()->parent());
341   combiner.call_graph_ = CallGraph::Build(module);
342   absl::flat_hash_map<int64, int64> visited_pairs;
343   return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
344 }
345 
InstructionsComputeSameValue(HloInstruction * i1,HloInstruction * i2,absl::flat_hash_map<int64,int64> * visited_pairs)346 bool ArCrsCombiner::InstructionsComputeSameValue(
347     HloInstruction* i1, HloInstruction* i2,
348     absl::flat_hash_map<int64, int64>* visited_pairs) {
349   if (i1 == i2) {
350     return true;
351   }
352   auto uid1 = i1->unique_id();
353   auto uid2 = i2->unique_id();
354   auto min_uid = std::min(uid1, uid2);
355   auto max_uid = std::max(uid1, uid2);
356   auto it = visited_pairs->find(min_uid);
357   if (it != visited_pairs->end() && max_uid == it->second) {
358     return true;
359   }
360   auto opcode1 = i1->opcode();
361   auto operands1 = i1->operands();
362   if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
363     return false;
364   }
365   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
366     return *a == *b;
367   };
368   // Two MPMD AllReduces are identical if they have the same channel_id. Their
369   // operands don't have to be identical.
370   auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
371     return true;
372   };
373   if (i1->IsCrossModuleAllReduce()) {
374     return i1->Identical(*i2, eq_operands, eq_computations,
375                          /*layout_sensitive=*/false);
376   }
377   visited_pairs->emplace(min_uid, max_uid);
378   for (int i = 0; i < operands1.size(); ++i) {
379     auto operand1 = operands1[i];
380     auto operand2 = i2->operands()[i];
381     if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
382       return false;
383     }
384   }
385   if (opcode1 == HloOpcode::kParameter) {
386     // In the general case, we don't try to prove equality of parameters.
387     // We only try in the context of get-tuple-element
388     // (see TupleElementsComputeSameValue).
389     return false;
390   }
391   if (opcode1 == HloOpcode::kGetTupleElement) {
392     return i1->tuple_index() == i2->tuple_index() ||
393            TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
394                                          i2->tuple_index(), visited_pairs);
395   }
396   // Don't check that the operands are identical, because Identical can
397   // return false for instructions that compute the same value but are not
398   // identical, which we don't want. We have checked the arguments with
399   // InstructionsComputeSameValue earlier.
400   auto eq_instructions = [](const HloInstruction* i1,
401                             const HloInstruction* i2) -> bool { return true; };
402   return i1->Identical(*i2, eq_instructions, eq_computations,
403                        /*layout_sensitive=*/false);
404 }
405 
GroupAllReducesById(HloModule * module)406 void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
407   // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
408   // ... , (ARn, CRS).
409   // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
410   // and later find that AR1's distance from the CRS is longer, we discard
411   // AR2 and start tracking AR1. We put the discarded ids in this set, in order
412   // to skip processing of short paths when we encounter the other ARs that
413   // have the same id as AR2.
414   absl::flat_hash_set<int64> discarded_ar_ids;
415   for (HloComputation* computation : module->MakeNonfusionComputations()) {
416     for (HloInstruction* instruction : computation->instructions()) {
417       auto maybe_pair = MatchesArCrsPattern(instruction);
418       if (maybe_pair) {
419         auto pair = *maybe_pair;
420         int64 ar_id = *(instruction->channel_id());
421         if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
422           continue;
423         }
424         auto it = crs_reserved_map_.find(pair.crs);
425         if (it != crs_reserved_map_.end()) {
426           auto prev_ar_id = it->second;
427           // Since there is another AR paired with CRS,
428           // all_reduce_map_[prev_ar_id] should exist, but
429           // all_reduce_map_[ar_id] shouldn't.
430           CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
431           CHECK_NE(prev_ar_id, ar_id);
432           auto prev_pair = all_reduce_map_[prev_ar_id].back();
433           int64 prev_distance = prev_pair.distance;
434           if (prev_distance < pair.distance) {
435             // The current AR's distance to CRS is longer than the previously
436             // tracked AR, so we discard the previous AR.
437             VLOG(2) << "Replacing ArCrsPair: " << prev_pair.ToString()
438                     << " with ArCrsPair: " << pair.ToString();
439             all_reduce_map_.erase(prev_ar_id);
440             discarded_ar_ids.insert(prev_ar_id);
441             all_reduce_map_[ar_id].push_back(pair);
442             crs_reserved_map_[pair.crs] = ar_id;
443           } else {
444             // Discard the current AR id because we are keeping the previously
445             // tracked AR.
446             discarded_ar_ids.insert(ar_id);
447           }
448         } else {
449           if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
450             int64 prev_distance = all_reduce_map_[ar_id].back().distance;
451             CHECK_EQ(prev_distance, pair.distance)
452                 << "All ARs with the same AR ID must have the same distance "
453                    "from the corresponding CRSs. Found: "
454                 << prev_distance << " and " << pair.distance;
455           }
456           all_reduce_map_[ar_id].push_back(pair);
457           crs_reserved_map_[pair.crs] = ar_id;
458         }
459       }
460     }
461   }
462 }
463 
KeepProvablyEqualInstructionGroupsMPMD()464 Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
465   for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
466     auto copy_it = it++;  // Advance `it` before invalidation from erase.
467     auto channel_id = copy_it->first;
468     VLOG(2)
469         << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
470         << channel_id << "\n";
471     auto pairs_vec = copy_it->second;
472     TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
473     auto instr_0 = pairs_vec[0].ar;
474     for (int i = 1; i < pairs_vec.size(); ++i) {
475       auto instr_i = pairs_vec[i].ar;
476       auto next_0 = instr_0->users()[0];
477       auto next_i = instr_i->users()[0];
478       absl::flat_hash_map<int64, int64> visited_pairs;
479       while (true) {
480         if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
481           all_reduce_map_.erase(copy_it);
482           VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
483                      "channel id: "
484                   << channel_id << "\n";
485           break;
486         }
487         if (next_0->IsCrossReplicaAllReduce()) {
488           break;
489         }
490         next_0 = next_0->users()[0];
491         next_i = next_i->users()[0];
492       }
493     }
494   }
495   return Status::OK();
496 }
497 
KeepProvablyEqualInstructionGroupsSPMD(HloModule * module)498 Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
499     HloModule* module) {
500   // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
501   // equivalence across partitions.
502   TF_ASSIGN_OR_RETURN(
503       auto replication_analysis,
504       HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
505 
506   for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
507     auto copy_it = it++;  // Advance `it` before invalidation from erase.
508     auto channel_id = copy_it->first;
509     VLOG(2)
510         << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
511         << channel_id << "\n";
512     auto pairs_vec = copy_it->second;
513     TF_RET_CHECK(pairs_vec.size() == 1);
514     auto instr = pairs_vec[0].ar;
515     auto next = instr->users()[0];
516     while (true) {
517       // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
518       // guarantee that the HLO produces an array.
519       TF_RET_CHECK(next->shape().IsArray());
520       if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
521         all_reduce_map_.erase(copy_it);
522         VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
523                    "channel id: "
524                 << channel_id << "\n";
525         break;
526       }
527       if (next->IsCrossReplicaAllReduce()) {
528         break;
529       }
530       next = next->users()[0];
531     }
532   }
533   return Status::OK();
534 }
535 
RewriteGraph()536 StatusOr<bool> ArCrsCombiner::RewriteGraph() {
537   if (all_reduce_map_.empty()) {
538     return false;
539   }
540   for (const auto& it : all_reduce_map_) {
541     auto pairs_vec = it.second;
542     for (auto pair : pairs_vec) {
543       auto all_reduce = pair.ar;
544       auto parent_computation = all_reduce->parent();
545       auto channel_id = all_reduce->channel_id();
546       auto prev = all_reduce->mutable_operand(0);
547       auto next = all_reduce->users()[0];
548       TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
549       TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
550       while (!next->IsCrossReplicaAllReduce()) {
551         switch (next->opcode()) {
552           case HloOpcode::kBitcast:
553           case HloOpcode::kTranspose:
554           case HloOpcode::kReshape:
555           case HloOpcode::kConvert:
556           case HloOpcode::kMultiply:
557             break;
558           case HloOpcode::kAdd:
559           case HloOpcode::kSubtract: {
560             auto other_operand = (next->operands()[0] == prev)
561                                      ? next->operands()[1]
562                                      : next->operands()[0];
563             // To move the AR past the addition/subtraction, we need to divide
564             // other_operand by the number of spatial partitions, except if
565             // other_operand is a cross-module AR, which can be eliminated.
566             if (other_operand->IsCrossModuleAllReduce() &&
567                 other_operand->user_count() == 1) {
568               TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
569                   other_operand->mutable_operand(0)));
570             } else {
571               auto shape = other_operand->shape();
572               Literal lit(shape);
573               lit.PopulateWithValue<float>(num_spatial_partitions_);
574               auto divisor = parent_computation->AddInstruction(
575                   HloInstruction::CreateConstant(lit.Clone()));
576               auto division = parent_computation->AddInstruction(
577                   HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
578                                                other_operand, divisor));
579               TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
580             }
581             break;
582           }
583           default:
584             LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
585         }
586         prev = next;
587         next = next->users()[0];
588       }
589       // The AllReduce and the CRS are combined to an all-core AllReduce.
590       //
591       // Note that we can just reuse the ReplicaGroup config of cross-replica
592       // all-reduce since we already checked that cross-partition all-reduce
593       // is always across all partitions (HasCombinableReplicaGroup). We need to
594       // combine ReplicaGroup configs using global ids here if we relax that
595       // restriction.
596       next->set_channel_id(channel_id);
597     }
598   }
599   return true;
600 }
601 
Run(HloModule * module)602 StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
603   call_graph_ = CallGraph::Build(module);
604 
605   GroupAllReducesById(module);
606 
607   if (spmd_partition_) {
608     TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
609   } else {
610     TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
611   }
612 
613   TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph());
614 
615   if (num_replicas_ > 1 && spmd_partition_) {
616     TF_ASSIGN_OR_RETURN(auto replaced,
617                         ReplaceReplicatedAllReduce(module, num_replicas_,
618                                                    num_spatial_partitions_));
619     changed |= replaced;
620   }
621 
622   return changed;
623 }
624 
625 }  // namespace xla
626