1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_module_group_util.h"

#include <algorithm>
#include <functional>
#include <list>
#include <queue>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
38
39 namespace xla {
40
GlobalPredecessors(HloInstruction * instruction)41 std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
42 HloInstruction* instruction) {
43 std::vector<HloInstruction*>
44 predecessors; // Use a vector to avoid non-determinism.
45 absl::flat_hash_set<HloInstruction*> unique;
46
47 // Adds to the unique predecessors list; if the predecessors is a companion
48 // instruction, also add companion instructions; if the predecessors is a
49 // cross-module all-reduce, also add the all-reduce instructions in the same
50 // group.
51 auto add_unique_predecessor = [&](HloInstruction* predecessor) {
52 if (unique.find(predecessor) != unique.end()) {
53 return;
54 }
55 if (metadata_.IsCompanionInstruction(predecessor)) {
56 for (HloInstruction* instr : metadata_.Companions(predecessor)) {
57 if (unique.insert(instr).second) {
58 predecessors.push_back(instr);
59 }
60 }
61 return;
62 }
63 if (predecessor->IsCrossModuleAllReduce()) {
64 for (HloInstruction* instr :
65 metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
66 if (unique.insert(instr).second) {
67 predecessors.push_back(instr);
68 }
69 }
70 return;
71 }
72 unique.insert(predecessor);
73 predecessors.push_back(predecessor);
74 };
75 // If the given instruction is a companion instruction, we need to find the
76 // predecessors of all of its companion instructions. If the instruction is an
77 // all-reduce, we need to find the predecessors of all the peer all-reduce
78 // instructions.
79 std::vector<HloInstruction*> instruction_group;
80 if (metadata_.IsCompanionInstruction(instruction)) {
81 for (HloInstruction* companion : metadata_.Companions(instruction)) {
82 instruction_group.push_back(companion);
83 }
84 } else if (instruction->IsCrossModuleAllReduce()) {
85 instruction_group =
86 metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
87 } else {
88 instruction_group.push_back(instruction);
89 }
90
91 for (HloInstruction* hlo : instruction_group) {
92 for (HloInstruction* operand : hlo->operands()) {
93 add_unique_predecessor(operand);
94 }
95 for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
96 add_unique_predecessor(control_predecessor);
97 }
98 }
99 if (instruction->opcode() == HloOpcode::kRecvDone &&
100 !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
101 // Send is a remote predecessor of RecvDone.
102 HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
103 add_unique_predecessor(send);
104 }
105 if (instruction->opcode() == HloOpcode::kSend &&
106 !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
107 // Recv is a remote predecessor of Send.
108 HloInstruction* recv_done =
109 metadata_.GetChannel(instruction->channel_id()).recv_done;
110 CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
111 CHECK_EQ(recv_done->operand_count(), 1);
112 HloInstruction* recv = recv_done->mutable_operand(0);
113 add_unique_predecessor(recv);
114 }
115 return predecessors;
116 }
117
GlobalSuccessors(HloInstruction * instruction)118 std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
119 HloInstruction* instruction) {
120 std::vector<HloInstruction*>
121 successors; // Use a vector to avoid non-determinism.
122 absl::flat_hash_set<HloInstruction*> unique;
123
124 // Adds to the unique successors list; if the successor is a companion
125 // instruction, also add companion instructions; if the successor is a
126 // cross-module all-reduce, also add the all-reduce instructions in the same
127 // group.
128 auto add_unique_successor = [&](HloInstruction* successor) {
129 if (unique.find(successor) != unique.end()) {
130 return;
131 }
132 if (metadata_.IsCompanionInstruction(successor)) {
133 for (HloInstruction* instr : metadata_.Companions(successor)) {
134 if (unique.insert(instr).second) {
135 successors.push_back(instr);
136 }
137 }
138 return;
139 }
140 if (successor->IsCrossModuleAllReduce()) {
141 for (HloInstruction* instr :
142 metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
143 if (unique.insert(instr).second) {
144 successors.push_back(instr);
145 }
146 }
147 return;
148 }
149 unique.insert(successor);
150 successors.push_back(successor);
151 };
152
153 // If the given instruction is a companion instruction, we need to find the
154 // successors of all of its companion instructions. If the instruction is an
155 // all-reduce, we need to find the successors of all its peer all-reduce
156 // instructions.
157 std::vector<HloInstruction*> instruction_group;
158 if (metadata_.IsCompanionInstruction(instruction)) {
159 for (HloInstruction* companion : metadata_.Companions(instruction)) {
160 instruction_group.push_back(companion);
161 }
162 } else if (instruction->IsCrossModuleAllReduce()) {
163 instruction_group =
164 metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
165 } else {
166 instruction_group.push_back(instruction);
167 }
168
169 for (HloInstruction* hlo : instruction_group) {
170 for (HloInstruction* user : hlo->users()) {
171 add_unique_successor(user);
172 }
173 for (HloInstruction* control_successor : hlo->control_successors()) {
174 add_unique_successor(control_successor);
175 }
176 }
177 if (instruction->opcode() == HloOpcode::kRecv &&
178 !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
179 // Send is a remote successor of Recv.
180 const HloInstruction* recv_done = instruction->users().front();
181 CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
182 HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
183 add_unique_successor(send);
184 }
185 if (instruction->opcode() == HloOpcode::kSend &&
186 !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
187 // RecvDone is a remote successor of Send.
188 HloInstruction* recv_done =
189 metadata_.GetChannel(instruction->channel_id()).recv_done;
190 add_unique_successor(recv_done);
191 }
192 return successors;
193 }
194
RootInstructions(absl::Span<HloComputation * const> computations)195 std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
196 absl::Span<HloComputation* const> computations) {
197 std::vector<HloInstruction*> roots;
198 for (HloComputation* computation : computations) {
199 for (HloInstruction* instruction : computation->instructions()) {
200 if (GlobalSuccessors(instruction).empty()) {
201 // An instruction that has no successors, e.g., an unused instruction,
202 // is in roots, even though it's not the ROOT of its computation.
203 roots.push_back(instruction);
204 }
205 }
206 }
207 return roots;
208 }
209
CycleToString(HloInstruction * init_instruction)210 string HloModuleGroupUtil::CycleToString(HloInstruction* init_instruction) {
211 std::vector<string> names;
212 absl::flat_hash_set<HloInstruction*> seen;
213
214 std::function<bool(HloInstruction*)> helper =
215 [&](HloInstruction* instruction) {
216 if (seen.find(instruction) != seen.end()) {
217 if (instruction == init_instruction) {
218 names.push_back(instruction->name());
219 return true;
220 }
221 return false;
222 }
223 seen.insert(instruction);
224 for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
225 bool init_found = helper(predecessor);
226 if (init_found) {
227 names.push_back(instruction->name());
228 return true;
229 }
230 }
231 return false;
232 };
233
234 helper(init_instruction);
235 std::vector<string> pieces;
236 pieces.reserve(names.size());
237 for (auto name : names) {
238 pieces.push_back(name);
239 }
240 return absl::StrJoin(pieces, " --> ");
241 }
242
// Performs an iterative depth-first traversal over global predecessors
// (GlobalPredecessors) starting at `root`. `visit_function` is called once
// per instruction group, in post order: only after the predecessors of every
// group member have been visited.
//
// An instruction group is one of:
//   - the companion set of an instruction,
//   - the cross-module all-reduces sharing its all_reduce_id,
//   - the instruction alone.
// All members of a group always move through the visit states together.
//
// `visit_state` persists across calls, so groups already marked kVisited by
// an earlier call are skipped.
// NOTE(review): `(*visit_state)[x]` presumably default-inserts kNotVisited
// for unseen instructions (VisitStates is declared in the header) — confirm.
//
// Returns:
//   - FailedPrecondition when a predecessor is found that is still being
//     visited (a cycle),
//   - a TF_RET_CHECK failure if group members are in inconsistent states,
//   - any error returned by `visit_function`.
Status HloModuleGroupUtil::VisitTopologicalOrder(
    VisitStates* visit_state, const VisitFunction& visit_function,
    HloInstruction* root) {
  // Stack of HLO instructions visited in DFS order.
  std::stack<HloInstruction*> stack;
  stack.push(root);

  while (!stack.empty()) {
    // Peek only. An instruction stays on the stack while its predecessors
    // (pushed above it) are processed. It is popped once it is visited.
    HloInstruction* hlo = stack.top();

    // Find the instruction group of the currently visited instruction. The
    // instruction group is one of:
    //   - all companion instructions of the current instruction,
    //   - all the all-reduce instructions that belong to the same group,
    //   - the instruction alone.
    // Group members are treated as a single entity for the purpose of the
    // traversal (i.e., they must always be in the same visit state).
    std::vector<HloInstruction*> instruction_group;
    if (metadata_.IsCompanionInstruction(hlo)) {
      for (HloInstruction* companion : metadata_.Companions(hlo)) {
        instruction_group.push_back(companion);
      }
    } else if (hlo->IsCrossModuleAllReduce()) {
      instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
    } else {
      instruction_group.push_back(hlo);
    }

    // Already fully processed (e.g. a duplicate stack entry, or visited by an
    // earlier call): drop it.
    if ((*visit_state)[hlo] == VisitState::kVisited) {
      // All instructions in the group must be in the same state.
      for (HloInstruction* instruction : instruction_group) {
        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited);
      }
      stack.pop();
      continue;
    }

    // Second time on top of the stack: every predecessor pushed above it has
    // been handled, so the group can now be visited (post order).
    if ((*visit_state)[hlo] == VisitState::kVisiting) {
      TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group));

      // Set the visit state of all instructions in the group to kVisited.
      for (HloInstruction* instruction : instruction_group) {
        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting);
        (*visit_state)[instruction] = VisitState::kVisited;
      }
      stack.pop();
      continue;
    }

    // First time on top of the stack (kNotVisited): mark the whole group as
    // in progress, then push predecessors above it without popping it.
    // Set the visit state of all instructions in the group to kVisiting.
    for (HloInstruction* instruction : instruction_group) {
      TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited)
          << instruction->ToString();
      (*visit_state)[instruction] = VisitState::kVisiting;
    }

    // For each instruction in the group, visit its predecessors (operands,
    // control predecessors and remote predecessors).
    for (HloInstruction* instruction : instruction_group) {
      for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
        // Visiting a node that is already being visited implies that there is
        // a cycle. Generate an error with the list of instructions in the
        // cycle.
        if ((*visit_state)[predecessor] == VisitState::kVisiting) {
          return FailedPrecondition(
              "Cross-computation cycle detected via communicating nodes.\n%s",
              CycleToString(predecessor));
        }
        stack.push(predecessor);
      }
    }
  }

  return Status::OK();
}
316
VerifyComputations(absl::Span<HloComputation * const> computations)317 Status HloModuleGroupUtil::VerifyComputations(
318 absl::Span<HloComputation* const> computations) {
319 auto visit_function =
320 [&](HloInstruction* instruction,
321 const std::vector<HloInstruction*>& instruction_group) {
322 return Status::OK();
323 };
324 int64 instructions_count = 0;
325 VisitStates visit_states;
326 for (HloComputation* computation : computations) {
327 // Visit all instructions, and not just from the root instruction of the
328 // computation. This allows us to detect dead cycles (i.e., cycles that
329 // are not reachable from the root) or to enforce an order for the
330 // communication instructions that are not reachable from any roots.
331 for (HloInstruction* instruction : computation->instructions()) {
332 TF_RETURN_IF_ERROR(
333 VisitTopologicalOrder(&visit_states, visit_function, instruction));
334 }
335 instructions_count += computation->instruction_count();
336 }
337
338 // Check if all instructions are visited and are in the visited state.
339 TF_RET_CHECK(visit_states.size() == instructions_count);
340 for (auto& state : visit_states) {
341 TF_RET_CHECK(state.second == VisitState::kVisited);
342 }
343
344 return Status::OK();
345 }
346
347 StatusOr<std::unique_ptr<HloReachabilityMap>>
ComputeReachability(absl::Span<HloComputation * const> computations)348 HloModuleGroupUtil::ComputeReachability(
349 absl::Span<HloComputation* const> computations) {
350 std::vector<HloInstruction*> post_order;
351 auto visit_function =
352 [&](HloInstruction* instruction,
353 const std::vector<HloInstruction*>& instruction_group) {
354 post_order.insert(post_order.end(), instruction_group.begin(),
355 instruction_group.end());
356 return Status::OK();
357 };
358 HloModuleGroupUtil::VisitStates visit_states;
359 for (HloInstruction* root : RootInstructions(computations)) {
360 TF_RETURN_IF_ERROR(
361 VisitTopologicalOrder(&visit_states, visit_function, root));
362 }
363 auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
364 for (HloInstruction* hlo : post_order) {
365 reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
366 }
367 return std::move(reachability);
368 }
369
UpdateReachabilityThroughInstruction(HloInstruction * instruction,HloReachabilityMap * reachability_map)370 void HloModuleGroupUtil::UpdateReachabilityThroughInstruction(
371 HloInstruction* instruction, HloReachabilityMap* reachability_map) {
372 std::queue<HloInstruction*> worklist;
373 worklist.push(instruction);
374
375 while (!worklist.empty()) {
376 HloInstruction* item = worklist.front();
377 worklist.pop();
378 if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item),
379 item)) {
380 for (HloInstruction* successor : GlobalSuccessors(item)) {
381 worklist.push(successor);
382 }
383 }
384 }
385 }
386
387 } // namespace xla
388