/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {
namespace {

// In SPMD mode, if there's a cross-replica all-reduce that produces the same
// value for all partitions, replace it with a global all-reduce and then
// divide by the number of partitions. Depending on the topology and the
// backend's implementation of the all-reduce, this may give better
// performance.
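//
// The identity behind the rewrite: if a value v is identical on all P
// partitions of each replica, then
//   cross-replica-sum(v) == global-sum(v) / P,
// since the global all-reduce adds each replica's copy of v P times.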
StatusOr<bool> ReplaceReplicatedAllReduce(HloModule* module,
                                          int64 replica_count,
                                          int64 partition_count) {
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  bool changed = false;
  int64 next_channel = hlo_query::NextChannelId(*module);
  for (auto computation : module->computations()) {
    for (auto instruction : computation->instructions()) {
      if (auto ar = DynCast<HloAllReduceInstruction>(instruction)) {
        const Shape& shape = ar->shape();
        if (ar->channel_id()) {
          continue;
        }
        if (ar->replica_groups().size() > 1) {
          continue;
        }
        if (shape.IsTuple() || shape.element_type() != F32) {
          continue;
        }
        // We would need a cost model for the target, but in general we want
        // to rewrite only if the replica count in the original op was large.
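        // (For example, with 8 partitions the rewrite only fires for 64 or
        // more replicas.)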
        if (replica_count < 8 * partition_count) {
          continue;
        }
        if (replication_analysis->HloInstructionIsReplicatedAt(ar, {})) {
          VLOG(2) << "Replaced replicated all-reduce: " << ar->ToString();
          ar->set_channel_id(next_channel++);
          auto divisor =
              computation->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::CreateR0<float>(partition_count)));
          auto bcast = computation->AddInstruction(
              HloInstruction::CreateBroadcast(shape, divisor, {}));
          auto div = computation->AddInstruction(HloInstruction::CreateBinary(
              ar->shape(), HloOpcode::kDivide, ar, bcast));
          TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div));
          changed = true;
        }
      }
    }
  }
  return changed;
}

// Returns true if the given instruction (must be a cross-partition all-reduce)
// has a ReplicaGroup config that can be combined with a cross-replica
// all-reduce. We currently restrict this to groups where all partitions in
// each replica belong to the same group.
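//
// For illustration: with num_replicas=2 and num_partitions=2 under global
// device ids, the combinable config is {{0,1},{2,3}}, i.e. one group per
// replica covering all of that replica's partitions.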
bool HasCombinableReplicaGroup(HloInstruction* hlo, int64 num_replicas,
                               int64 num_partitions) {
  auto all_reduce = Cast<HloAllReduceInstruction>(hlo);
  auto replica_groups = all_reduce->replica_groups();
  CHECK(all_reduce->IsCrossModuleAllReduce());

  if (all_reduce->use_global_device_ids()) {
    if (replica_groups.size() != num_replicas) {
      return false;
    }
    for (const auto& group : replica_groups) {
      if (group.replica_ids_size() != num_partitions) {
        return false;
      }
      std::unordered_set<int64> partition_ids;
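      // With global device ids, a flattened id encodes
      // replica_id * num_partitions + partition_id, so the quotient below is
      // the replica and the remainder is the partition.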
      int64 replica_id = group.replica_ids(0) / num_partitions;
      for (int64 i = 0; i < num_partitions; ++i) {
        if (group.replica_ids(i) / num_partitions != replica_id) {
          return false;
        }
        partition_ids.insert(group.replica_ids(i) % num_partitions);
      }
      if (partition_ids.size() != num_partitions) {
        return false;
      }
    }
    return true;
  }

  return replica_groups.size() == num_replicas;
}

}  // namespace

namespace m = match;

// Checks if the argument instruction is an AllReduce, followed by a certain
// sequence of instructions and then a CRS. It must be possible to move
// the AR past each instruction in the sequence.
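//
// An illustrative match (instruction names are hypothetical):
//   %ar  = bf16[8] all-reduce(%x), channel_id=1   <- cross-module AR
//   %cvt = f32[8] convert(%ar)
//   %crs = f32[8] all-reduce(%cvt)                <- cross-replica CRS
// yields an ArCrsPair(%ar, %crs) with distance 2.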
absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
    HloInstruction* instruction) {
  auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
    if (instruction->user_count() != 1) {
      return false;
    }
    switch (instruction->opcode()) {
      case HloOpcode::kBitcast:
      case HloOpcode::kTranspose:
      case HloOpcode::kReshape:
        return true;
      case HloOpcode::kConvert:
        // Can be moved across if both input and output are either float or
        // integer (e.g. S32<->U32 or F32<->BF16).
        return ShapeUtil::ElementIsFloating(instruction->shape()) ==
               ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
      case HloOpcode::kAdd:
      case HloOpcode::kSubtract:
      case HloOpcode::kMultiply:
        // Only supported for floating point operands.
        return ShapeUtil::ElementIsFloating(instruction->shape());
      default:
        return false;
    }
  };

  auto computation_is_addition = [](HloComputation* c) {
    return c->instruction_count() == 3 &&
           Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
  };

  // We only support combining cross-partition all-reduces where each replica
  // belongs to its own group, since the later cross-replica all-reduce
  // combines along the replica dimension.
  if (instruction->IsCrossModuleAllReduce() &&
      HasCombinableReplicaGroup(instruction, num_replicas_,
                                num_spatial_partitions_) &&
      computation_is_addition(instruction->called_computations()[0]) &&
      instruction->user_count() == 1) {
    auto next = instruction->users()[0];
    int64 distance = 1;
    while (!next->IsCrossReplicaAllReduce()) {
      if (can_ar_move_past_instruction(next)) {
        next = next->users()[0];
      } else {
        return absl::nullopt;
      }
      ++distance;
    }
    if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
        computation_is_addition(next->called_computations()[0])) {
      ArCrsPair pair(instruction, next, distance);
      VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString();
      return pair;
    }
  }
  return absl::nullopt;
}

absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kWhile) {
      return caller_instruction;
    }
  }
  return absl::nullopt;
}

absl::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kConditional) {
      return caller_instruction;
    }
  }
  return absl::nullopt;
}

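// Walks the dataflow backwards from `instruction` (looking through domains,
// get-tuple-elements, whiles and conditionals) and collects every kTuple
// instruction whose value can reach it. Returns nullopt when the traversal
// hits an opcode it cannot see through.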
absl::optional<std::vector<HloInstruction*>> ArCrsCombiner::GetAllTuples(
    HloInstruction* instruction,
    absl::flat_hash_set<HloInstruction*>* visited) {
  if (visited->find(instruction) != visited->end()) {
    return std::vector<HloInstruction*>();
  }
  visited->insert(instruction);

  switch (instruction->opcode()) {
    case HloOpcode::kTuple: {
      return std::vector<HloInstruction*>({instruction});
    }
    case HloOpcode::kDomain: {
      return GetAllTuples(instruction->operands()[0], visited);
    }
    case HloOpcode::kParameter: {
      auto maybe_while = WhileFromBodyParameter(instruction);
      if (maybe_while) {
        auto while_instr = *maybe_while;
        auto init_tuples = GetAllTuples(while_instr->while_init(), visited);
        auto body_tuples = GetAllTuples(
            while_instr->while_body()->root_instruction(), visited);
        if (!init_tuples || !body_tuples) {
          return absl::nullopt;
        }
        auto result = *init_tuples;
        result.insert(result.end(), body_tuples->begin(), body_tuples->end());
        return result;
      }
      auto maybe_conditional = ConditionalFromBodyParameter(instruction);
      if (maybe_conditional) {
        auto cond_instr = *maybe_conditional;
        std::vector<HloInstruction*> tuples;
        for (int64 i = 0; i < cond_instr->branch_computations().size(); ++i) {
          if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
              instruction) {
            // If the same computation is used for more than one branch of the
            // conditional, we collect the arguments that flow to the
            // computation from all branches.
            auto branch_tuples =
                GetAllTuples(cond_instr->mutable_operand(i + 1), visited);
            if (!branch_tuples) {
              return absl::nullopt;
            }
            tuples.insert(tuples.end(), branch_tuples->begin(),
                          branch_tuples->end());
          }
        }
        return tuples;
      }
      return absl::nullopt;
    }
    case HloOpcode::kGetTupleElement: {
      std::vector<HloInstruction*> result_tuples;
      auto tuples = GetAllTuples(instruction->operands()[0], visited);
      if (!tuples) {
        return absl::nullopt;
      }
      for (auto tuple : *tuples) {
        auto tmp_tuples = GetAllTuples(
            tuple->mutable_operand(instruction->tuple_index()), visited);
        if (!tmp_tuples) {
          return absl::nullopt;
        }
        result_tuples.insert(result_tuples.end(), tmp_tuples->begin(),
                             tmp_tuples->end());
      }
      return result_tuples;
    }
    case HloOpcode::kConditional: {
      std::vector<HloInstruction*> result_tuples;
      for (HloComputation* body : instruction->branch_computations()) {
        if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
          return absl::nullopt;
        }
        result_tuples.push_back(body->root_instruction());
      }
      return result_tuples;
    }
    case HloOpcode::kWhile: {
      auto init_tuples = GetAllTuples(instruction->while_init(), visited);
      auto body_tuples =
          GetAllTuples(instruction->while_body()->root_instruction(), visited);
      if (!init_tuples || !body_tuples) {
        return absl::nullopt;
      }
      auto result = *init_tuples;
      result.insert(result.end(), body_tuples->begin(), body_tuples->end());
      return result;
    }
    default:
      return absl::nullopt;
  }
}

bool ArCrsCombiner::TupleElementsComputeSameValue(
    HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
    absl::flat_hash_map<int64, int64>* visited_pairs) {
  absl::flat_hash_set<HloInstruction*> visited;
  auto tuples = GetAllTuples(tuple_shaped_instruction, &visited);
  if (!tuples) {
    return false;
  }
  for (auto tuple : *tuples) {
    CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
                                      tuple->mutable_operand(i2),
                                      visited_pairs)) {
      return false;
    }
  }
  return true;
}

/* static */
bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
                                                     HloInstruction* i2) {
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
                         /*spmd_partition=*/false);
  auto module = i1->parent()->parent();
  CHECK_EQ(module, i2->parent()->parent());
  combiner.call_graph_ = CallGraph::Build(module);
  absl::flat_hash_map<int64, int64> visited_pairs;
  return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
}

bool ArCrsCombiner::InstructionsComputeSameValue(
    HloInstruction* i1, HloInstruction* i2,
    absl::flat_hash_map<int64, int64>* visited_pairs) {
  if (i1 == i2) {
    return true;
  }
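  // Visited (uid, uid) pairs are memoized as "assumed equal". Recording a
  // pair before recursing into the operands also breaks cycles that arise
  // when comparing values flowing around while loops.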
  auto uid1 = i1->unique_id();
  auto uid2 = i2->unique_id();
  auto min_uid = std::min(uid1, uid2);
  auto max_uid = std::max(uid1, uid2);
  auto it = visited_pairs->find(min_uid);
  if (it != visited_pairs->end() && max_uid == it->second) {
    return true;
  }
  auto opcode1 = i1->opcode();
  auto operands1 = i1->operands();
  if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
    return false;
  }
  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };
  // Two MPMD AllReduces are identical if they have the same channel_id. Their
  // operands don't have to be identical.
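  // (Cross-module all-reduces with the same channel_id participate in one
  // combined collective across modules, so they produce the same value on
  // every module regardless of their local operands.)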
  auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
    return true;
  };
  if (i1->IsCrossModuleAllReduce()) {
    return i1->Identical(*i2, eq_operands, eq_computations,
                         /*layout_sensitive=*/false);
  }
  visited_pairs->emplace(min_uid, max_uid);
  for (int i = 0; i < operands1.size(); ++i) {
    auto operand1 = operands1[i];
    auto operand2 = i2->operands()[i];
    if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
      return false;
    }
  }
  if (opcode1 == HloOpcode::kParameter) {
    // In the general case, we don't try to prove equality of parameters.
    // We only try in the context of get-tuple-element
    // (see TupleElementsComputeSameValue).
    return false;
  }
  if (opcode1 == HloOpcode::kGetTupleElement) {
    return i1->tuple_index() == i2->tuple_index() ||
           TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
                                         i2->tuple_index(), visited_pairs);
  }
  // Don't check that the operands are identical, because Identical can
  // return false for instructions that compute the same value but are not
  // identical, which we don't want. We have checked the arguments with
  // InstructionsComputeSameValue earlier.
  auto eq_instructions = [](const HloInstruction* i1,
                            const HloInstruction* i2) -> bool { return true; };
  return i1->Identical(*i2, eq_instructions, eq_computations,
                       /*layout_sensitive=*/false);
}

void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
  // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
  // ... , (ARn, CRS).
  // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
  // and later find that AR1's distance from the CRS is longer, we discard
  // AR2 and start tracking AR1. We put the discarded ids in this set, in
  // order to skip processing of short paths when we encounter the other ARs
  // that have the same id as AR2.
  absl::flat_hash_set<int64> discarded_ar_ids;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      auto maybe_pair = MatchesArCrsPattern(instruction);
      if (maybe_pair) {
        auto pair = *maybe_pair;
        int64 ar_id = *(instruction->channel_id());
        if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
          continue;
        }
        auto it = crs_reserved_map_.find(pair.crs);
        if (it != crs_reserved_map_.end()) {
          auto prev_ar_id = it->second;
          // Since there is another AR paired with the CRS,
          // all_reduce_map_[prev_ar_id] should exist, but
          // all_reduce_map_[ar_id] shouldn't.
          CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
          CHECK_NE(prev_ar_id, ar_id);
          auto prev_pair = all_reduce_map_[prev_ar_id].back();
          int64 prev_distance = prev_pair.distance;
          if (prev_distance < pair.distance) {
            // The current AR's distance to the CRS is longer than the
            // previously tracked AR's, so we discard the previous AR.
            VLOG(2) << "Replacing ArCrsPair: " << prev_pair.ToString()
                    << " with ArCrsPair: " << pair.ToString();
            all_reduce_map_.erase(prev_ar_id);
            discarded_ar_ids.insert(prev_ar_id);
            all_reduce_map_[ar_id].push_back(pair);
            crs_reserved_map_[pair.crs] = ar_id;
          } else {
            // Discard the current AR id because we are keeping the previously
            // tracked AR.
            discarded_ar_ids.insert(ar_id);
          }
        } else {
          if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
            int64 prev_distance = all_reduce_map_[ar_id].back().distance;
            CHECK_EQ(prev_distance, pair.distance)
                << "All ARs with the same AR ID must have the same distance "
                   "from the corresponding CRSs. Found: "
                << prev_distance << " and " << pair.distance;
          }
          all_reduce_map_[ar_id].push_back(pair);
          crs_reserved_map_[pair.crs] = ar_id;
        }
      }
    }
  }
}

Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
    auto instr_0 = pairs_vec[0].ar;
    for (int i = 1; i < pairs_vec.size(); ++i) {
      auto instr_i = pairs_vec[i].ar;
      auto next_0 = instr_0->users()[0];
      auto next_i = instr_i->users()[0];
      absl::flat_hash_map<int64, int64> visited_pairs;
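      // Walk the two AR->...->CRS paths in lockstep; if any pair of
      // corresponding instructions cannot be proven to compute the same
      // value, drop the whole group.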
      while (true) {
        if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
          all_reduce_map_.erase(copy_it);
          VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                     "channel id: "
                  << channel_id << "\n";
          break;
        }
        if (next_0->IsCrossReplicaAllReduce()) {
          break;
        }
        next_0 = next_0->users()[0];
        next_i = next_i->users()[0];
      }
    }
  }
  return Status::OK();
}

Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
    HloModule* module) {
  // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
  // equivalence across partitions.
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == 1);
    auto instr = pairs_vec[0].ar;
    auto next = instr->users()[0];
    while (true) {
      // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
      // guarantee that the HLO produces an array.
      TF_RET_CHECK(next->shape().IsArray());
      if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
        all_reduce_map_.erase(copy_it);
        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                   "channel id: "
                << channel_id << "\n";
        break;
      }
      if (next->IsCrossReplicaAllReduce()) {
        break;
      }
      next = next->users()[0];
    }
  }
  return Status::OK();
}

StatusOr<bool> ArCrsCombiner::RewriteGraph() {
  if (all_reduce_map_.empty()) {
    return false;
  }
  for (const auto& it : all_reduce_map_) {
    auto pairs_vec = it.second;
    for (auto pair : pairs_vec) {
      auto all_reduce = pair.ar;
      auto parent_computation = all_reduce->parent();
      auto channel_id = all_reduce->channel_id();
      auto prev = all_reduce->mutable_operand(0);
      auto next = all_reduce->users()[0];
      TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
      TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
      while (!next->IsCrossReplicaAllReduce()) {
        switch (next->opcode()) {
          case HloOpcode::kBitcast:
          case HloOpcode::kTranspose:
          case HloOpcode::kReshape:
          case HloOpcode::kConvert:
          case HloOpcode::kMultiply:
            break;
          case HloOpcode::kAdd:
          case HloOpcode::kSubtract: {
            auto other_operand = (next->operands()[0] == prev)
                                     ? next->operands()[1]
                                     : next->operands()[0];
            // To move the AR past the addition/subtraction, we need to divide
            // other_operand by the number of spatial partitions, unless
            // other_operand is a cross-module AR, which can be eliminated
            // instead.
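            // Why dividing works: the final all-reduce now sums over all
            // replicas *and* partitions, so a term that is identical across
            // the P partitions of a replica gets added P times instead of
            // once; pre-dividing it by P restores the original sum. A
            // cross-module AR operand is itself a per-partition sum, so
            // dropping the AR lets the final all-reduce perform it.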
            if (other_operand->IsCrossModuleAllReduce() &&
                other_operand->user_count() == 1) {
              TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
                  other_operand->mutable_operand(0)));
            } else {
              auto shape = other_operand->shape();
              Literal lit(shape);
              lit.PopulateWithValue<float>(num_spatial_partitions_);
              auto divisor = parent_computation->AddInstruction(
                  HloInstruction::CreateConstant(lit.Clone()));
              auto division = parent_computation->AddInstruction(
                  HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
                                               other_operand, divisor));
              TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
            }
            break;
          }
          default:
            LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
        }
        prev = next;
        next = next->users()[0];
      }
      // The AllReduce and the CRS are combined into an all-core AllReduce.
      //
      // Note that we can just reuse the ReplicaGroup config of the
      // cross-replica all-reduce, since we already checked that the
      // cross-partition all-reduce is always across all partitions
      // (HasCombinableReplicaGroup). We would need to combine ReplicaGroup
      // configs using global ids here if we relaxed that restriction.
      next->set_channel_id(channel_id);
    }
  }
  return true;
}

StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
  call_graph_ = CallGraph::Build(module);

  GroupAllReducesById(module);

  if (spmd_partition_) {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
  } else {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
  }

  TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph());

  if (num_replicas_ > 1 && spmd_partition_) {
    TF_ASSIGN_OR_RETURN(auto replaced,
                        ReplaceReplicatedAllReduce(module, num_replicas_,
                                                   num_spatial_partitions_));
    changed |= replaced;
  }

  return changed;
}

}  // namespace xla