1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
17
18 #include <sstream>
19 #include <string>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace xla {
34
ToString() const35 string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
36 string repr =
37 (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL";
38 switch (kind_) {
39 case ComputationKind::kInvalid:
40 repr += ":INVALID";
41 break;
42 case ComputationKind::kWhileCondition:
43 repr += ":WHILE_CONDITION";
44 break;
45 case ComputationKind::kWhileBody:
46 repr += ":WHILE_BODY";
47 break;
48 case ComputationKind::kConditionalBranch:
49 repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_);
50 break;
51 case ComputationKind::kCallFunction:
52 repr += ":CALL";
53 break;
54 }
55 return repr;
56 }
57
58 /* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
Build(absl::Span<HloModule * const> modules)59 HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) {
60 auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
61 TF_RETURN_IF_ERROR(metadata->Build());
62 return std::move(metadata);
63 }
64
Build()65 Status HloModuleGroupMetadata::Build() {
66 TF_RETURN_IF_ERROR(RecordInstructions());
67 TF_RETURN_IF_ERROR(VerifyChannelInstructions());
68
69 // Record all companion while instructions.
70 const auto visitor = [this](HloInstruction* hlo) -> Status {
71 // We only need to process if the instruction is within the computation
72 // of a companion instruction, like in the condition or body computation
73 // of a While.
74 const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
75 if (tracked == nullptr) {
76 return Status::OK();
77 }
78
79 if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) {
80 std::vector<HloComputation*> peers;
81 if (IsChannelInstruction(hlo)) {
82 peers.push_back(PeerComputation(hlo));
83 } else if (hlo->IsCrossModuleAllReduce()) {
84 for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
85 if (instr == hlo) {
86 continue;
87 }
88 peers.push_back(instr->parent());
89 }
90 }
91
92 // Add the parent computation of this channel (or all-reduce) instruction
93 // and its peer computation(s) (both must be while computations) as
94 // companions.
95 for (HloComputation* peer_computation : peers) {
96 const TrackedInstruction* peer_tracked =
97 GetTrackedInstruction(peer_computation);
98 TF_RET_CHECK(peer_tracked != nullptr)
99 << "Peer instruction is not a possible companion";
100 TF_RET_CHECK(*tracked == *peer_tracked)
101 << "Peer instruction does not match the computation kind";
102 TF_RETURN_IF_ERROR(
103 AddCompanion(tracked->instruction(), peer_tracked->instruction()));
104 tracked_instructions_comms_[tracked->instruction()].push_back(hlo);
105 }
106 } else if (IsCompanionInstruction(hlo)) {
107 // Add the parents of companion instructions (they must be all of the same
108 // kind of instructions, opcode wise) as companions.
109 for (HloInstruction* companion : Companions(hlo)) {
110 const TrackedInstruction* companion_tracked =
111 GetTrackedInstruction(companion->parent());
112 TF_RET_CHECK(companion_tracked != nullptr);
113 TF_RET_CHECK(*tracked == *companion_tracked);
114 TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(),
115 companion_tracked->instruction()));
116 }
117 }
118
119 return Status::OK();
120 };
121
122 // Visit the computations in postorder so that the companion information grows
123 // from inner computations to outer ones.
124 for (HloModule* module : modules_) {
125 for (HloComputation* computation : module->MakeComputationPostOrder()) {
126 TF_RETURN_IF_ERROR(computation->Accept(visitor));
127 }
128 }
129 TF_RETURN_IF_ERROR(VerifyCompanionSets());
130 if (VLOG_IS_ON(4)) {
131 DumpCollectedStats();
132 }
133
134 for (HloModule* module : modules_) {
135 TF_ASSIGN_OR_RETURN(
136 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
137 TuplePointsToAnalysis::Run(module));
138 points_to_analyses_[module] = std::move(points_to_analysis);
139 }
140
141 return Status::OK();
142 }
143
VerifyCompanionSets() const144 Status HloModuleGroupMetadata::VerifyCompanionSets() const {
145 for (const auto& companions : companion_sets_) {
146 // A companion set must be composed at most of an instruction per
147 // device/module.
148 std::unordered_set<int64> devices;
149 for (HloInstruction* instruction : *companions) {
150 // Go through all the communicating instructions (send, recv) of the given
151 // companion, and record their device.
152 auto it = tracked_instructions_comms_.find(instruction);
153 if (it == tracked_instructions_comms_.end()) {
154 // Companions can be added even if they have no communicating
155 // instructions, if they are parent of companions.
156 continue;
157 }
158 std::unordered_set<int64> comm_devices;
159 for (HloInstruction* comm_instruction : it->second) {
160 auto device = GetInstructionDevice(*comm_instruction);
161 TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
162 << " does not have a device";
163 comm_devices.insert(*device);
164 }
165 for (int64 device : comm_devices) {
166 if (!devices.insert(device).second) {
167 std::stringstream ss;
168 ss << "Companion set:" << std::endl;
169 for (HloInstruction* hlo : *companions) {
170 ss << " " << hlo->name() << std::endl;
171 }
172 ss << "has multiple instructions on the same device";
173 return FailedPrecondition("%s", ss.str());
174 }
175 }
176 }
177 }
178 return Status::OK();
179 }
180
IsChannelInstruction(const HloInstruction * instruction) const181 bool HloModuleGroupMetadata::IsChannelInstruction(
182 const HloInstruction* instruction) const {
183 switch (instruction->opcode()) {
184 case HloOpcode::kSend:
185 case HloOpcode::kRecv:
186 case HloOpcode::kSendDone:
187 case HloOpcode::kRecvDone: {
188 const HloSendRecvInstruction* send_recv_instr =
189 DynCast<HloSendRecvInstruction>(instruction);
190 CHECK(send_recv_instr != nullptr);
191 return !send_recv_instr->is_host_transfer();
192 }
193 default:
194 return false;
195 }
196 }
197
IsCompanionInstruction(HloInstruction * hlo) const198 bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
199 return companion_set_index_.contains(hlo);
200 }
201
InstructionCommunicates(HloInstruction * hlo) const202 bool HloModuleGroupMetadata::InstructionCommunicates(
203 HloInstruction* hlo) const {
204 return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) ||
205 hlo->IsCrossModuleAllReduce();
206 }
207
GetChannel(int64 channel_id) const208 const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
209 int64 channel_id) const {
210 CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end());
211 return channels_[channel_id_map_.at(channel_id)];
212 }
213
HasChannel(int64 channel_id) const214 bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
215 return channel_id_map_.find(channel_id) != channel_id_map_.end();
216 }
217
PeerComputation(const HloInstruction * instruction) const218 HloComputation* HloModuleGroupMetadata::PeerComputation(
219 const HloInstruction* instruction) const {
220 CHECK(IsChannelInstruction(instruction));
221 const Channel& channel = GetChannel(instruction->channel_id());
222 switch (instruction->opcode()) {
223 case HloOpcode::kSend:
224 case HloOpcode::kSendDone:
225 return channel.recv->parent();
226 case HloOpcode::kRecv:
227 case HloOpcode::kRecvDone:
228 return channel.send->parent();
229 default:
230 LOG(FATAL) << "opcode not supported";
231 }
232 }
233
GetAllReduceGroup(int64 all_reduce_id) const234 const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
235 int64 all_reduce_id) const {
236 auto it = all_reduce_map_.find(all_reduce_id);
237 CHECK(it != all_reduce_map_.end());
238 return it->second;
239 }
240
241 std::vector<HloModuleGroupMetadata::TrackedInstruction>
GetCompanionsPath(const HloInstruction * hlo) const242 HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
243 std::vector<TrackedInstruction> path;
244 const HloComputation* parent = hlo->parent();
245 const TrackedInstruction* companion;
246 while ((companion = GetTrackedInstruction(parent)) != nullptr) {
247 parent = companion->instruction()->parent();
248 path.push_back(*companion);
249 }
250 return path;
251 }
252
CheckCompanionPathsCompatibility(const std::vector<TrackedInstruction> & path0,const std::vector<TrackedInstruction> & path1) const253 bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility(
254 const std::vector<TrackedInstruction>& path0,
255 const std::vector<TrackedInstruction>& path1) const {
256 if (path0.size() != path1.size()) {
257 VLOG(5) << "Companion path size do not match: " << path0.size()
258 << " != " << path1.size();
259 return false;
260 }
261 for (int64 i = 0; i < path0.size(); ++i) {
262 if (path0[i] != path1[i]) {
263 VLOG(5) << "Companion instructions at path index " << i
264 << " do not have the same opcode: " << path0[i].ToString()
265 << " vs " << path1[i].ToString();
266 return false;
267 }
268 }
269 return true;
270 }
271
GetModuleId(const HloModule * module) const272 int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
273 for (int64 i = 0; i < modules_.size(); ++i) {
274 if (modules_[i] == module) {
275 return i;
276 }
277 }
278 LOG(FATAL) << "unknown module";
279 }
280
GetInstructionDevice(const HloInstruction & instruction) const281 absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
282 const HloInstruction& instruction) const {
283 // The module group metadata can be created in both "single module, multiple
284 // devices" and "multiple modules, no explicit devices" fashions.
285 // The API returns an optional even though the current implementation always
286 // returns a device, to account for cases where we cannot guess a device.
287 // In such cases the VerifyChannelInstructions() will return proper errors.
288 absl::optional<int64> device = instruction.sharding_unique_device();
289 if (!device) {
290 device = GetModuleId(instruction.parent()->parent());
291 }
292 return device;
293 }
294
GetDeviceModulesCount() const295 int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
296 return modules_.size();
297 }
298
RecordInstructions()299 Status HloModuleGroupMetadata::RecordInstructions() {
300 const auto visitor = [this](HloInstruction* hlo) -> Status {
301 if (hlo->opcode() == HloOpcode::kWhile) {
302 tracked_instructions_[hlo->while_condition()] =
303 TrackedInstruction(hlo, ComputationKind::kWhileCondition);
304 tracked_instructions_[hlo->while_body()] =
305 TrackedInstruction(hlo, ComputationKind::kWhileBody);
306 } else if (hlo->opcode() == HloOpcode::kConditional) {
307 for (int b = 0; b < hlo->branch_count(); ++b) {
308 tracked_instructions_[hlo->branch_computation(b)] =
309 TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b);
310 }
311 } else if (hlo->opcode() == HloOpcode::kCall) {
312 tracked_instructions_[hlo->to_apply()] =
313 TrackedInstruction(hlo, ComputationKind::kCallFunction);
314 }
315
316 // Group cross module all-reduce instructions by the all_reduce id.
317 if (hlo->IsCrossModuleAllReduce()) {
318 TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
319 channel_id_map_.end())
320 << "all_reduce_id " << *hlo->all_reduce_id()
321 << " is already used by a send/recv instruction";
322 all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
323 max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
324 return Status::OK();
325 }
326
327 if (!IsChannelInstruction(hlo)) {
328 return Status::OK();
329 }
330
331 TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
332 all_reduce_map_.end())
333 << "channel id " << hlo->channel_id()
334 << " is already used by an all-reduce instruction";
335
336 // Add a new channel if needed.
337 if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
338 channels_.emplace_back();
339 channels_.back().id = hlo->channel_id();
340 channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
341 max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
342 }
343 Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
344
345 if (hlo->opcode() == HloOpcode::kSend) {
346 TF_RET_CHECK(channel.send == nullptr)
347 << "channel id " << hlo->channel_id()
348 << " is used by multiple send instructions";
349 channel.send = hlo;
350 }
351 if (hlo->opcode() == HloOpcode::kRecv) {
352 TF_RET_CHECK(channel.recv == nullptr)
353 << "channel id " << hlo->channel_id()
354 << " is used by multiple recv instructions";
355 channel.recv = hlo;
356 }
357 if (hlo->opcode() == HloOpcode::kSendDone) {
358 TF_RET_CHECK(channel.send_done == nullptr)
359 << "channel id " << hlo->channel_id()
360 << " is used by multiple send-done instructions";
361 channel.send_done = hlo;
362 }
363 if (hlo->opcode() == HloOpcode::kRecvDone) {
364 TF_RET_CHECK(channel.recv_done == nullptr)
365 << "channel id " << hlo->channel_id()
366 << " is used by multiple recv-done instructions";
367 channel.recv_done = hlo;
368 }
369 return Status::OK();
370 };
371
372 for (HloModule* module : modules_) {
373 for (auto* computation : module->computations()) {
374 TF_RETURN_IF_ERROR(computation->Accept(visitor));
375 }
376 }
377 VLOG(2) << "Created " << channels_.size() << " channels";
378 VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups";
379 return Status::OK();
380 }
381
AddCompanion(HloInstruction * instruction1,HloInstruction * instruction2)382 Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
383 HloInstruction* instruction2) {
384 TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
385 instruction1->opcode() == HloOpcode::kConditional ||
386 instruction1->opcode() == HloOpcode::kCall);
387 VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
388 << instruction2->ToString();
389 if (instruction1 == instruction2) {
390 return Status::OK();
391 } else if (!ContainsKey(companion_set_index_, instruction1) &&
392 !ContainsKey(companion_set_index_, instruction2)) {
393 companion_sets_.push_back(
394 absl::make_unique<std::vector<HloInstruction*>>());
395 auto companion_set = companion_sets_.back().get();
396 companion_set->push_back(instruction1);
397 companion_set->push_back(instruction2);
398 companion_set_index_[instruction1] = companion_sets_.size() - 1;
399 companion_set_index_[instruction2] = companion_sets_.size() - 1;
400 } else if (!ContainsKey(companion_set_index_, instruction1)) {
401 companion_sets_[companion_set_index_[instruction2]]->push_back(
402 instruction1);
403 companion_set_index_[instruction1] = companion_set_index_[instruction2];
404 } else if (!ContainsKey(companion_set_index_, instruction2)) {
405 companion_sets_[companion_set_index_[instruction1]]->push_back(
406 instruction2);
407 companion_set_index_[instruction2] = companion_set_index_[instruction1];
408 } else if (companion_set_index_[instruction1] !=
409 companion_set_index_[instruction2]) {
410 // At any point while building the companion sets, each instruction belongs
411 // to at most 1 companion set, so the union of two companion sets is
412 // concatenating two disjoint sets.
413 absl::c_copy(Companions(instruction2),
414 std::back_inserter(
415 *companion_sets_[companion_set_index_[instruction1]]));
416 int64 index_to_remove = companion_set_index_[instruction2];
417 for (HloInstruction* hlo : Companions(instruction2)) {
418 companion_set_index_[hlo] = companion_set_index_[instruction1];
419 }
420 // We can't remove the set from the vector because companion_set_index_
421 // references sets by their index in this vector, so we reset to nullptr
422 // instead.
423 companion_sets_[index_to_remove].reset(nullptr);
424 }
425 return Status::OK();
426 }
427
VerifyChannelInstructions()428 Status HloModuleGroupMetadata::VerifyChannelInstructions() {
429 for (const Channel& channel : channels_) {
430 if (channel.send == nullptr) {
431 return FailedPrecondition("missing send for id : %d", channel.id);
432 }
433 if (channel.recv == nullptr) {
434 return FailedPrecondition("missing recv for id : %d", channel.id);
435 }
436 if (channel.send_done == nullptr) {
437 return FailedPrecondition("missing send-done for id : %d", channel.id);
438 }
439 if (channel.recv_done == nullptr) {
440 return FailedPrecondition("missing recv-done for id : %d", channel.id);
441 }
442 }
443
444 // Check if the shapes match for each channel.
445 for (const Channel& channel : channels_) {
446 const Shape& send_shape = channel.send->operand(0)->shape();
447 const Shape& recv_shape =
448 ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0);
449 if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
450 return FailedPrecondition("send/recv shapes do not match");
451 }
452 auto send_device = GetInstructionDevice(*channel.send);
453 auto send_done_device = GetInstructionDevice(*channel.send_done);
454 if (!send_device) {
455 return FailedPrecondition("send instruction must have a device: %s",
456 channel.send->ToString());
457 }
458 if (!send_done_device) {
459 return FailedPrecondition("send_done instruction must have a device: %s",
460 channel.send_done->ToString());
461 }
462 if (*send_device != *send_done_device) {
463 return FailedPrecondition(
464 "send and send-done (channel=%d) must be on the same device: %d "
465 "vs. %d",
466 channel.id, *send_device, *send_done_device);
467 }
468 auto recv_device = GetInstructionDevice(*channel.recv);
469 auto recv_done_device = GetInstructionDevice(*channel.recv_done);
470 if (!recv_done_device) {
471 return FailedPrecondition("recv_done instruction must have a device: %s",
472 channel.recv_done->ToString());
473 }
474 if (*recv_device != *recv_done_device) {
475 return FailedPrecondition(
476 "recv and recv-done (channel=%d) must be on the same device: %d "
477 "vs. %d",
478 channel.id, *recv_device, *recv_done_device);
479 }
480 if (*send_device == *recv_device) {
481 return FailedPrecondition(
482 "send and recv (channel=%d) must be on different devices: %d",
483 channel.id, *send_device);
484 }
485 }
486
487 for (const Channel& channel : channels_) {
488 TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
489 TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
490 TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
491 TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
492 }
493 // Check if the nest levels match for each channel.
494 for (const Channel& channel : channels_) {
495 std::vector<TrackedInstruction> path = GetCompanionsPath(channel.send);
496 if (!CheckCompanionPathsCompatibility(
497 path, GetCompanionsPath(channel.send_done)) ||
498 !CheckCompanionPathsCompatibility(path,
499 GetCompanionsPath(channel.recv)) ||
500 !CheckCompanionPathsCompatibility(
501 path, GetCompanionsPath(channel.recv_done))) {
502 return FailedPrecondition(
503 "Nest companion paths do not match for channel %d", channel.id);
504 }
505 }
506 return Status::OK();
507 }
508
CheckCommunicatingInstruction(HloInstruction * instruction) const509 Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
510 HloInstruction* instruction) const {
511 HloComputation* computation = instruction->parent();
512 const HloModule* module = computation->parent();
513 if (module->entry_computation() == computation ||
514 tracked_instructions_.contains(computation)) {
515 return Status::OK();
516 }
517 return FailedPrecondition("channel is used in disallowed computation");
518 }
519
DumpCollectedStats() const520 void HloModuleGroupMetadata::DumpCollectedStats() const {
521 std::map<std::pair<int64, int64>, int64> communication_histogram;
522 for (auto& channel : channels_) {
523 auto from_device = GetInstructionDevice(*channel.send);
524 auto to_device = GetInstructionDevice(*channel.recv);
525 LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device
526 << " to_device=" << *to_device << " send=" << channel.send->name()
527 << " send_done=" << channel.send_done->name()
528 << " recv=" << channel.recv->name()
529 << " recv_done=" << channel.recv_done->name();
530 communication_histogram[std::pair<int64, int64>(*from_device,
531 *to_device)] += 1;
532 }
533 for (auto& fromto_count : communication_histogram) {
534 LOG(INFO) << "From " << fromto_count.first.first << " to "
535 << fromto_count.first.second << ": " << fromto_count.second;
536 }
537 for (auto& companion_set : companion_sets_) {
538 LOG(INFO) << "Companion set:";
539 for (HloInstruction* instruction : *companion_set) {
540 LOG(INFO) << " " << instruction->name();
541 }
542 }
543 for (auto& instruction_comm : tracked_instructions_comms_) {
544 LOG(INFO) << "Communicating instruction " << instruction_comm.first->name();
545 for (HloInstruction* instruction : instruction_comm.second) {
546 auto device = GetInstructionDevice(*instruction);
547 LOG(INFO) << " " << instruction->name() << " on device " << *device;
548 }
549 }
550 }
551
552 } // namespace xla
553