1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"

#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
35 namespace m = match;
36 
37 // Checks if the argument instruction is an AllReduce, followed by a certain
38 // sequence of instructions and then a CRS. It must be possible to move
39 // the AR past each instruction in the sequence. Returns the CRS, which is the
40 // last instruction in the sequence.
MatchesArCrsPattern(HloInstruction * instruction)41 absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
42     HloInstruction* instruction) {
43   auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
44     if (instruction->user_count() != 1) {
45       return false;
46     }
47     switch (instruction->opcode()) {
48       case HloOpcode::kBitcast:
49       case HloOpcode::kTranspose:
50       case HloOpcode::kReshape:
51         return true;
52       case HloOpcode::kConvert:
53         // Can be moved across if both input and output is either float or
54         // integer (e.g. S32<->U32 or F32<->BF16)
55         return ShapeUtil::ElementIsFloating(instruction->shape()) ==
56                ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
57       case HloOpcode::kAdd:
58       case HloOpcode::kSubtract:
59       case HloOpcode::kMultiply:
60         // Only supported for floating point operands.
61         return ShapeUtil::ElementIsFloating(instruction->shape());
62       default:
63         return false;
64     }
65   };
66 
67   auto computation_is_addition = [](HloComputation* c) {
68     return c->instruction_count() == 3 &&
69            Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
70   };
71 
72   if (!instruction->IsCrossModuleAllReduce() ||
73       !computation_is_addition(instruction->called_computations()[0]) ||
74       instruction->user_count() != 1) {
75     return absl::nullopt;
76   }
77   auto next = instruction->users()[0];
78   int64 distance = 1;
79   while (!next->IsCrossReplicaAllReduce()) {
80     if (can_ar_move_past_instruction(next)) {
81       next = next->users()[0];
82     } else {
83       return absl::nullopt;
84     }
85     ++distance;
86   }
87   if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
88       computation_is_addition(next->called_computations()[0])) {
89     return absl::optional<ArCrsPair>(ArCrsPair(instruction, next, distance));
90   } else {
91     return absl::nullopt;
92   }
93 }
94 
WhileFromBodyParameter(HloInstruction * instruction)95 absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
96     HloInstruction* instruction) {
97   CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
98   HloComputation* computation = instruction->parent();
99   auto caller_instructions = call_graph_->GetComputationCallers(computation);
100   if (caller_instructions.size() == 1) {
101     auto caller_instruction = caller_instructions[0];
102     if (caller_instruction->opcode() == HloOpcode::kWhile) {
103       return caller_instruction;
104     }
105   }
106   return absl::nullopt;
107 }
108 
GetAllTuples(HloInstruction * instruction)109 std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples(
110     HloInstruction* instruction) {
111   if (instruction->opcode() == HloOpcode::kTuple) {
112     return {instruction};
113   }
114   if (instruction->opcode() == HloOpcode::kDomain) {
115     return GetAllTuples(instruction->operands()[0]);
116   }
117   if (instruction->opcode() == HloOpcode::kParameter) {
118     auto maybe_while = WhileFromBodyParameter(instruction);
119     if (!maybe_while) {
120       return {};
121     }
122     auto while_instr = *maybe_while;
123     auto init_tuples = GetAllTuples(while_instr->while_init());
124     auto body_tuples =
125         GetAllTuples(while_instr->while_body()->root_instruction());
126     if (init_tuples.empty() || body_tuples.empty()) {
127       return {};
128     }
129     init_tuples.insert(init_tuples.end(), body_tuples.begin(),
130                        body_tuples.end());
131     return init_tuples;
132   }
133   if (instruction->opcode() == HloOpcode::kGetTupleElement) {
134     std::vector<HloInstruction*> result_tuples;
135     for (auto tuple : GetAllTuples(instruction->operands()[0])) {
136       auto tmp_tuples =
137           GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
138       if (tmp_tuples.empty()) {
139         return {};
140       }
141       result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
142                            tmp_tuples.end());
143     }
144     return result_tuples;
145   }
146   return {};
147 }
148 
TupleElementsComputeSameValue(HloInstruction * tuple_shaped_instruction,int64 i1,int64 i2,absl::flat_hash_map<int64,int64> * visited_pairs)149 bool ArCrsCombiner::TupleElementsComputeSameValue(
150     HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
151     absl::flat_hash_map<int64, int64>* visited_pairs) {
152   auto tuples = GetAllTuples(tuple_shaped_instruction);
153   if (tuples.empty()) {
154     return false;
155   }
156   for (auto tuple : tuples) {
157     CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
158     if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
159                                       tuple->mutable_operand(i2),
160                                       visited_pairs)) {
161       return false;
162     }
163   }
164   return true;
165 }
166 
167 /* static */
TestInstructionsComputeSameValue(HloInstruction * i1,HloInstruction * i2)168 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
169                                                      HloInstruction* i2) {
170   ArCrsCombiner combiner(/*num_spatial_partitions=*/2);
171   auto module = i1->parent()->parent();
172   CHECK_EQ(module, i2->parent()->parent());
173   combiner.call_graph_ = CallGraph::Build(module);
174   absl::flat_hash_map<int64, int64> visited_pairs;
175   return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
176 }
177 
InstructionsComputeSameValue(HloInstruction * i1,HloInstruction * i2,absl::flat_hash_map<int64,int64> * visited_pairs)178 bool ArCrsCombiner::InstructionsComputeSameValue(
179     HloInstruction* i1, HloInstruction* i2,
180     absl::flat_hash_map<int64, int64>* visited_pairs) {
181   if (i1 == i2) {
182     return true;
183   }
184   auto uid1 = i1->unique_id();
185   auto uid2 = i2->unique_id();
186   auto min_uid = std::min(uid1, uid2);
187   auto max_uid = std::max(uid1, uid2);
188   auto it = visited_pairs->find(min_uid);
189   if (it != visited_pairs->end() && max_uid == it->second) {
190     return true;
191   }
192   auto opcode1 = i1->opcode();
193   auto operands1 = i1->operands();
194   if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
195     return false;
196   }
197   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
198     return *a == *b;
199   };
200   if (i1->IsCrossModuleAllReduce()) {
201     return i1->Identical(*i2,
202                          /*eq_operands=*/std::equal_to<const HloInstruction*>(),
203                          eq_computations,
204                          /*layout_sensitive=*/false);
205   }
206   visited_pairs->emplace(min_uid, max_uid);
207   for (int i = 0; i < operands1.size(); ++i) {
208     auto operand1 = operands1[i];
209     auto operand2 = i2->operands()[i];
210     if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
211       return false;
212     }
213   }
214   if (opcode1 == HloOpcode::kParameter) {
215     // In the general case, we don't try to prove equality of parameters.
216     // We only try in the context of get-tuple-element
217     // (see TupleElementsComputeSameValue).
218     return false;
219   }
220   if (opcode1 == HloOpcode::kGetTupleElement) {
221     return i1->tuple_index() == i2->tuple_index() ||
222            TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
223                                          i2->tuple_index(), visited_pairs);
224   }
225   // Don't check that the operands are identical, because Identical can
226   // return false for instructions that compute the same value but are not
227   // identical, which we don't want. We have checked the arguments with
228   // InstructionsComputeSameValue earlier.
229   auto eq_instructions = [](const HloInstruction* i1,
230                             const HloInstruction* i2) -> bool { return true; };
231   return i1->Identical(*i2, eq_instructions, eq_computations,
232                        /*layout_sensitive=*/false);
233 }
234 
GroupAllReducesById(HloModule * module)235 void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
236   // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
237   // ... , (ARn, CRS).
238   // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
239   // and later find that AR1's distance from the CRS is longer, we discard
240   // AR2 and start tracking AR1. We put the discarded ids in this set, in order
241   // to skip processing of short paths when we encounter the other ARs that
242   // have the same id as AR2.
243   absl::flat_hash_set<int64> discarded_ar_ids;
244   for (HloComputation* computation : module->MakeNonfusionComputations()) {
245     for (HloInstruction* instruction : computation->instructions()) {
246       auto maybe_pair = MatchesArCrsPattern(instruction);
247       if (maybe_pair) {
248         auto pair = *maybe_pair;
249         int64 ar_id = *(instruction->all_reduce_id());
250         if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
251           continue;
252         }
253         auto it = crs_reserved_map_.find(pair.crs);
254         if (it != crs_reserved_map_.end()) {
255           auto prev_ar_id = it->second;
256           // Since there is another AR paired with CRS,
257           // all_reduce_map_[prev_ar_id] should exist, but
258           // all_reduce_map_[ar_id] shouldn't.
259           CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
260           CHECK_NE(prev_ar_id, ar_id);
261           auto prev_pair = all_reduce_map_[prev_ar_id].back();
262           int64 prev_distance = prev_pair.distance;
263           if (prev_distance < pair.distance) {
264             // The current AR's distance to CRS is longer than the previously
265             // tracked AR, so we discard the previous AR.
266             all_reduce_map_.erase(prev_ar_id);
267             discarded_ar_ids.insert(prev_ar_id);
268             all_reduce_map_[ar_id].push_back(pair);
269             crs_reserved_map_[pair.crs] = ar_id;
270           } else {
271             // Discard the current AR id because we are keeping the previously
272             // tracked AR.
273             discarded_ar_ids.insert(ar_id);
274           }
275         } else {
276           if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
277             int64 prev_distance = all_reduce_map_[ar_id].back().distance;
278             CHECK_EQ(prev_distance, pair.distance)
279                 << "All ARs with the same AR ID must have the same distance "
280                    "from the corresponding CRSs. Found: "
281                 << prev_distance << " and " << pair.distance;
282           }
283           all_reduce_map_[ar_id].push_back(pair);
284           crs_reserved_map_[pair.crs] = ar_id;
285         }
286       }
287     }
288   }
289 }
290 
KeepProvablyEqualInstructionGroups()291 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
292   for (auto it : all_reduce_map_) {
293     auto all_reduce_id = it.first;
294     auto pairs_vec = it.second;
295     CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
296     auto instr_0 = pairs_vec[0].ar;
297     for (int i = 1; i < pairs_vec.size(); ++i) {
298       auto instr_i = pairs_vec[i].ar;
299       auto next_0 = instr_0->users()[0];
300       auto next_i = instr_i->users()[0];
301       absl::flat_hash_map<int64, int64> visited_pairs;
302       while (true) {
303         if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
304           all_reduce_map_.erase(all_reduce_id);
305           break;
306         }
307         if (next_0->IsCrossReplicaAllReduce()) {
308           break;
309         }
310         next_0 = next_0->users()[0];
311         next_i = next_i->users()[0];
312       }
313     }
314   }
315 }
316 
RewriteGraph()317 StatusOr<bool> ArCrsCombiner::RewriteGraph() {
318   if (all_reduce_map_.empty()) {
319     return false;
320   }
321   for (auto it : all_reduce_map_) {
322     auto pairs_vec = it.second;
323     for (auto pair : pairs_vec) {
324       auto all_reduce = pair.ar;
325       auto parent_computation = all_reduce->parent();
326       auto all_reduce_id = all_reduce->all_reduce_id();
327       auto prev = all_reduce->mutable_operand(0);
328       auto next = all_reduce->users()[0];
329       TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
330       TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
331       while (!next->IsCrossReplicaAllReduce()) {
332         switch (next->opcode()) {
333           case HloOpcode::kBitcast:
334           case HloOpcode::kTranspose:
335           case HloOpcode::kReshape:
336           case HloOpcode::kConvert:
337           case HloOpcode::kMultiply:
338             break;
339           case HloOpcode::kAdd:
340           case HloOpcode::kSubtract: {
341             auto other_operand = (next->operands()[0] == prev)
342                                      ? next->operands()[1]
343                                      : next->operands()[0];
344             // To move the AR past the addition/subtraction, we need to divide
345             // other_operand by the number of spatial partitions, except if
346             // other_operand is a cross-module AR, which can be eliminated.
347             if (other_operand->IsCrossModuleAllReduce() &&
348                 other_operand->user_count() == 1) {
349               TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
350                   other_operand->mutable_operand(0)));
351             } else {
352               auto shape = other_operand->shape();
353               Literal lit(shape);
354               lit.PopulateWithValue<float>(num_spatial_partitions_);
355               auto divisor = parent_computation->AddInstruction(
356                   HloInstruction::CreateConstant(lit.Clone()));
357               auto division = parent_computation->AddInstruction(
358                   HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
359                                                other_operand, divisor));
360               TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
361             }
362             break;
363           }
364           default:
365             LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
366         }
367         prev = next;
368         next = next->users()[0];
369       }
370       // The AllReduce and the CRS are combined to an all-core AllReduce.
371       next->set_all_reduce_id(all_reduce_id);
372     }
373   }
374   return true;
375 }
376 
Run(HloModule * module)377 StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
378   call_graph_ = CallGraph::Build(module);
379 
380   GroupAllReducesById(module);
381 
382   KeepProvablyEqualInstructionGroups();
383 
384   return RewriteGraph();
385 }
386 
387 }  // namespace xla
388