1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
17 
18 #include <sstream>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
ToString() const35 string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
36   string repr =
37       (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL";
38   switch (kind_) {
39     case ComputationKind::kInvalid:
40       repr += ":INVALID";
41       break;
42     case ComputationKind::kWhileCondition:
43       repr += ":WHILE_CONDITION";
44       break;
45     case ComputationKind::kWhileBody:
46       repr += ":WHILE_BODY";
47       break;
48     case ComputationKind::kConditionalBranch:
49       repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_);
50       break;
51     case ComputationKind::kCallFunction:
52       repr += ":CALL";
53       break;
54   }
55   return repr;
56 }
57 
58 /* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
Build(absl::Span<HloModule * const> modules)59 HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) {
60   auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
61   TF_RETURN_IF_ERROR(metadata->Build());
62   return std::move(metadata);
63 }
64 
Build()65 Status HloModuleGroupMetadata::Build() {
66   TF_RETURN_IF_ERROR(RecordInstructions());
67   TF_RETURN_IF_ERROR(VerifyChannelInstructions());
68 
69   // Record all companion while instructions.
70   const auto visitor = [this](HloInstruction* hlo) -> Status {
71     // We only need to process if the instruction is within the computation
72     // of a companion instruction, like in the condition or body computation
73     // of a While.
74     const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
75     if (tracked == nullptr) {
76       return Status::OK();
77     }
78 
79     if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) {
80       std::vector<HloComputation*> peers;
81       if (IsChannelInstruction(hlo)) {
82         peers.push_back(PeerComputation(hlo));
83       } else if (hlo->IsCrossModuleAllReduce()) {
84         for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
85           if (instr == hlo) {
86             continue;
87           }
88           peers.push_back(instr->parent());
89         }
90       }
91 
92       // Add the parent computation of this channel (or all-reduce) instruction
93       // and its peer computation(s) (both must be while computations) as
94       // companions.
95       for (HloComputation* peer_computation : peers) {
96         const TrackedInstruction* peer_tracked =
97             GetTrackedInstruction(peer_computation);
98         TF_RET_CHECK(peer_tracked != nullptr)
99             << "Peer instruction is not a possible companion";
100         TF_RET_CHECK(*tracked == *peer_tracked)
101             << "Peer instruction does not match the computation kind";
102         TF_RETURN_IF_ERROR(
103             AddCompanion(tracked->instruction(), peer_tracked->instruction()));
104         tracked_instructions_comms_[tracked->instruction()].push_back(hlo);
105       }
106     } else if (IsCompanionInstruction(hlo)) {
107       // Add the parents of companion instructions (they must be all of the same
108       // kind of instructions, opcode wise) as companions.
109       for (HloInstruction* companion : Companions(hlo)) {
110         const TrackedInstruction* companion_tracked =
111             GetTrackedInstruction(companion->parent());
112         TF_RET_CHECK(companion_tracked != nullptr);
113         TF_RET_CHECK(*tracked == *companion_tracked);
114         TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(),
115                                         companion_tracked->instruction()));
116       }
117     }
118 
119     return Status::OK();
120   };
121 
122   // Visit the computations in postorder so that the companion information grows
123   // from inner computations to outer ones.
124   for (HloModule* module : modules_) {
125     for (HloComputation* computation : module->MakeComputationPostOrder()) {
126       TF_RETURN_IF_ERROR(computation->Accept(visitor));
127     }
128   }
129   TF_RETURN_IF_ERROR(VerifyCompanionSets());
130   if (VLOG_IS_ON(4)) {
131     DumpCollectedStats();
132   }
133 
134   for (HloModule* module : modules_) {
135     TF_ASSIGN_OR_RETURN(
136         std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
137         TuplePointsToAnalysis::Run(module));
138     points_to_analyses_[module] = std::move(points_to_analysis);
139   }
140 
141   return Status::OK();
142 }
143 
VerifyCompanionSets() const144 Status HloModuleGroupMetadata::VerifyCompanionSets() const {
145   for (const auto& companions : companion_sets_) {
146     // A companion set must be composed at most of an instruction per
147     // device/module.
148     std::unordered_set<int64> devices;
149     for (HloInstruction* instruction : *companions) {
150       // Go through all the communicating instructions (send, recv) of the given
151       // companion, and record their device.
152       auto it = tracked_instructions_comms_.find(instruction);
153       if (it == tracked_instructions_comms_.end()) {
154         // Companions can be added even if they have no communicating
155         // instructions, if they are parent of companions.
156         continue;
157       }
158       std::unordered_set<int64> comm_devices;
159       for (HloInstruction* comm_instruction : it->second) {
160         auto device = GetInstructionDevice(*comm_instruction);
161         TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
162                              << " does not have a device";
163         comm_devices.insert(*device);
164       }
165       for (int64 device : comm_devices) {
166         if (!devices.insert(device).second) {
167           std::stringstream ss;
168           ss << "Companion set:" << std::endl;
169           for (HloInstruction* hlo : *companions) {
170             ss << "  " << hlo->name() << std::endl;
171           }
172           ss << "has multiple instructions on the same device";
173           return FailedPrecondition("%s", ss.str());
174         }
175       }
176     }
177   }
178   return Status::OK();
179 }
180 
IsChannelInstruction(const HloInstruction * instruction) const181 bool HloModuleGroupMetadata::IsChannelInstruction(
182     const HloInstruction* instruction) const {
183   switch (instruction->opcode()) {
184     case HloOpcode::kSend:
185     case HloOpcode::kRecv:
186     case HloOpcode::kSendDone:
187     case HloOpcode::kRecvDone: {
188       const HloSendRecvInstruction* send_recv_instr =
189           DynCast<HloSendRecvInstruction>(instruction);
190       CHECK(send_recv_instr != nullptr);
191       return !send_recv_instr->is_host_transfer();
192     }
193     default:
194       return false;
195   }
196 }
197 
IsCompanionInstruction(HloInstruction * hlo) const198 bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
199   return companion_set_index_.contains(hlo);
200 }
201 
InstructionCommunicates(HloInstruction * hlo) const202 bool HloModuleGroupMetadata::InstructionCommunicates(
203     HloInstruction* hlo) const {
204   return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) ||
205          hlo->IsCrossModuleAllReduce();
206 }
207 
GetChannel(int64 channel_id) const208 const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
209     int64 channel_id) const {
210   CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end());
211   return channels_[channel_id_map_.at(channel_id)];
212 }
213 
HasChannel(int64 channel_id) const214 bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
215   return channel_id_map_.find(channel_id) != channel_id_map_.end();
216 }
217 
PeerComputation(const HloInstruction * instruction) const218 HloComputation* HloModuleGroupMetadata::PeerComputation(
219     const HloInstruction* instruction) const {
220   CHECK(IsChannelInstruction(instruction));
221   const Channel& channel = GetChannel(instruction->channel_id());
222   switch (instruction->opcode()) {
223     case HloOpcode::kSend:
224     case HloOpcode::kSendDone:
225       return channel.recv->parent();
226     case HloOpcode::kRecv:
227     case HloOpcode::kRecvDone:
228       return channel.send->parent();
229     default:
230       LOG(FATAL) << "opcode not supported";
231   }
232 }
233 
GetAllReduceGroup(int64 all_reduce_id) const234 const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
235     int64 all_reduce_id) const {
236   auto it = all_reduce_map_.find(all_reduce_id);
237   CHECK(it != all_reduce_map_.end());
238   return it->second;
239 }
240 
241 std::vector<HloModuleGroupMetadata::TrackedInstruction>
GetCompanionsPath(const HloInstruction * hlo) const242 HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
243   std::vector<TrackedInstruction> path;
244   const HloComputation* parent = hlo->parent();
245   const TrackedInstruction* companion;
246   while ((companion = GetTrackedInstruction(parent)) != nullptr) {
247     parent = companion->instruction()->parent();
248     path.push_back(*companion);
249   }
250   return path;
251 }
252 
CheckCompanionPathsCompatibility(const std::vector<TrackedInstruction> & path0,const std::vector<TrackedInstruction> & path1) const253 bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility(
254     const std::vector<TrackedInstruction>& path0,
255     const std::vector<TrackedInstruction>& path1) const {
256   if (path0.size() != path1.size()) {
257     VLOG(5) << "Companion path size do not match: " << path0.size()
258             << " != " << path1.size();
259     return false;
260   }
261   for (int64 i = 0; i < path0.size(); ++i) {
262     if (path0[i] != path1[i]) {
263       VLOG(5) << "Companion instructions at path index " << i
264               << " do not have the same opcode: " << path0[i].ToString()
265               << " vs " << path1[i].ToString();
266       return false;
267     }
268   }
269   return true;
270 }
271 
GetModuleId(const HloModule * module) const272 int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
273   for (int64 i = 0; i < modules_.size(); ++i) {
274     if (modules_[i] == module) {
275       return i;
276     }
277   }
278   LOG(FATAL) << "unknown module";
279 }
280 
GetInstructionDevice(const HloInstruction & instruction) const281 absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
282     const HloInstruction& instruction) const {
283   // The module group metadata can be created in both "single module, multiple
284   // devices" and "multiple modules, no explicit devices" fashions.
285   // The API returns an optional even though the current implementation always
286   // returns a device, to account for cases where we cannot guess a device.
287   // In such cases the VerifyChannelInstructions() will return proper errors.
288   absl::optional<int64> device = instruction.sharding_unique_device();
289   if (!device) {
290     device = GetModuleId(instruction.parent()->parent());
291   }
292   return device;
293 }
294 
GetDeviceModulesCount() const295 int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
296   return modules_.size();
297 }
298 
RecordInstructions()299 Status HloModuleGroupMetadata::RecordInstructions() {
300   const auto visitor = [this](HloInstruction* hlo) -> Status {
301     if (hlo->opcode() == HloOpcode::kWhile) {
302       tracked_instructions_[hlo->while_condition()] =
303           TrackedInstruction(hlo, ComputationKind::kWhileCondition);
304       tracked_instructions_[hlo->while_body()] =
305           TrackedInstruction(hlo, ComputationKind::kWhileBody);
306     } else if (hlo->opcode() == HloOpcode::kConditional) {
307       for (int b = 0; b < hlo->branch_count(); ++b) {
308         tracked_instructions_[hlo->branch_computation(b)] =
309             TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b);
310       }
311     } else if (hlo->opcode() == HloOpcode::kCall) {
312       tracked_instructions_[hlo->to_apply()] =
313           TrackedInstruction(hlo, ComputationKind::kCallFunction);
314     }
315 
316     // Group cross module all-reduce instructions by the all_reduce id.
317     if (hlo->IsCrossModuleAllReduce()) {
318       TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
319                    channel_id_map_.end())
320           << "all_reduce_id " << *hlo->all_reduce_id()
321           << " is already used by a send/recv instruction";
322       all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
323       max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
324       return Status::OK();
325     }
326 
327     if (!IsChannelInstruction(hlo)) {
328       return Status::OK();
329     }
330 
331     TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
332                  all_reduce_map_.end())
333         << "channel id " << hlo->channel_id()
334         << " is already used by an all-reduce instruction";
335 
336     // Add a new channel if needed.
337     if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
338       channels_.emplace_back();
339       channels_.back().id = hlo->channel_id();
340       channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
341       max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
342     }
343     Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
344 
345     if (hlo->opcode() == HloOpcode::kSend) {
346       TF_RET_CHECK(channel.send == nullptr)
347           << "channel id " << hlo->channel_id()
348           << " is used by multiple send instructions";
349       channel.send = hlo;
350     }
351     if (hlo->opcode() == HloOpcode::kRecv) {
352       TF_RET_CHECK(channel.recv == nullptr)
353           << "channel id " << hlo->channel_id()
354           << " is used by multiple recv instructions";
355       channel.recv = hlo;
356     }
357     if (hlo->opcode() == HloOpcode::kSendDone) {
358       TF_RET_CHECK(channel.send_done == nullptr)
359           << "channel id " << hlo->channel_id()
360           << " is used by multiple send-done instructions";
361       channel.send_done = hlo;
362     }
363     if (hlo->opcode() == HloOpcode::kRecvDone) {
364       TF_RET_CHECK(channel.recv_done == nullptr)
365           << "channel id " << hlo->channel_id()
366           << " is used by multiple recv-done instructions";
367       channel.recv_done = hlo;
368     }
369     return Status::OK();
370   };
371 
372   for (HloModule* module : modules_) {
373     for (auto* computation : module->computations()) {
374       TF_RETURN_IF_ERROR(computation->Accept(visitor));
375     }
376   }
377   VLOG(2) << "Created " << channels_.size() << " channels";
378   VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups";
379   return Status::OK();
380 }
381 
AddCompanion(HloInstruction * instruction1,HloInstruction * instruction2)382 Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
383                                             HloInstruction* instruction2) {
384   TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
385                instruction1->opcode() == HloOpcode::kConditional ||
386                instruction1->opcode() == HloOpcode::kCall);
387   VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
388           << instruction2->ToString();
389   if (instruction1 == instruction2) {
390     return Status::OK();
391   } else if (!ContainsKey(companion_set_index_, instruction1) &&
392              !ContainsKey(companion_set_index_, instruction2)) {
393     companion_sets_.push_back(
394         absl::make_unique<std::vector<HloInstruction*>>());
395     auto companion_set = companion_sets_.back().get();
396     companion_set->push_back(instruction1);
397     companion_set->push_back(instruction2);
398     companion_set_index_[instruction1] = companion_sets_.size() - 1;
399     companion_set_index_[instruction2] = companion_sets_.size() - 1;
400   } else if (!ContainsKey(companion_set_index_, instruction1)) {
401     companion_sets_[companion_set_index_[instruction2]]->push_back(
402         instruction1);
403     companion_set_index_[instruction1] = companion_set_index_[instruction2];
404   } else if (!ContainsKey(companion_set_index_, instruction2)) {
405     companion_sets_[companion_set_index_[instruction1]]->push_back(
406         instruction2);
407     companion_set_index_[instruction2] = companion_set_index_[instruction1];
408   } else if (companion_set_index_[instruction1] !=
409              companion_set_index_[instruction2]) {
410     // At any point while building the companion sets, each instruction belongs
411     // to at most 1 companion set, so the union of two companion sets is
412     // concatenating two disjoint sets.
413     absl::c_copy(Companions(instruction2),
414                  std::back_inserter(
415                      *companion_sets_[companion_set_index_[instruction1]]));
416     int64 index_to_remove = companion_set_index_[instruction2];
417     for (HloInstruction* hlo : Companions(instruction2)) {
418       companion_set_index_[hlo] = companion_set_index_[instruction1];
419     }
420     // We can't remove the set from the vector because companion_set_index_
421     // references sets by their index in this vector, so we reset to nullptr
422     // instead.
423     companion_sets_[index_to_remove].reset(nullptr);
424   }
425   return Status::OK();
426 }
427 
VerifyChannelInstructions()428 Status HloModuleGroupMetadata::VerifyChannelInstructions() {
429   for (const Channel& channel : channels_) {
430     if (channel.send == nullptr) {
431       return FailedPrecondition("missing send for id : %d", channel.id);
432     }
433     if (channel.recv == nullptr) {
434       return FailedPrecondition("missing recv for id : %d", channel.id);
435     }
436     if (channel.send_done == nullptr) {
437       return FailedPrecondition("missing send-done for id : %d", channel.id);
438     }
439     if (channel.recv_done == nullptr) {
440       return FailedPrecondition("missing recv-done for id : %d", channel.id);
441     }
442   }
443 
444   // Check if the shapes match for each channel.
445   for (const Channel& channel : channels_) {
446     const Shape& send_shape = channel.send->operand(0)->shape();
447     const Shape& recv_shape =
448         ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0);
449     if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
450       return FailedPrecondition("send/recv shapes do not match");
451     }
452     auto send_device = GetInstructionDevice(*channel.send);
453     auto send_done_device = GetInstructionDevice(*channel.send_done);
454     if (!send_device) {
455       return FailedPrecondition("send instruction must have a device: %s",
456                                 channel.send->ToString());
457     }
458     if (!send_done_device) {
459       return FailedPrecondition("send_done instruction must have a device: %s",
460                                 channel.send_done->ToString());
461     }
462     if (*send_device != *send_done_device) {
463       return FailedPrecondition(
464           "send and send-done (channel=%d) must be on the same device: %d "
465           "vs. %d",
466           channel.id, *send_device, *send_done_device);
467     }
468     auto recv_device = GetInstructionDevice(*channel.recv);
469     auto recv_done_device = GetInstructionDevice(*channel.recv_done);
470     if (!recv_done_device) {
471       return FailedPrecondition("recv_done instruction must have a device: %s",
472                                 channel.recv_done->ToString());
473     }
474     if (*recv_device != *recv_done_device) {
475       return FailedPrecondition(
476           "recv and recv-done (channel=%d) must be on the same device: %d "
477           "vs. %d",
478           channel.id, *recv_device, *recv_done_device);
479     }
480     if (*send_device == *recv_device) {
481       return FailedPrecondition(
482           "send and recv (channel=%d) must be on different devices: %d",
483           channel.id, *send_device);
484     }
485   }
486 
487   for (const Channel& channel : channels_) {
488     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
489     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
490     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
491     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
492   }
493   // Check if the nest levels match for each channel.
494   for (const Channel& channel : channels_) {
495     std::vector<TrackedInstruction> path = GetCompanionsPath(channel.send);
496     if (!CheckCompanionPathsCompatibility(
497             path, GetCompanionsPath(channel.send_done)) ||
498         !CheckCompanionPathsCompatibility(path,
499                                           GetCompanionsPath(channel.recv)) ||
500         !CheckCompanionPathsCompatibility(
501             path, GetCompanionsPath(channel.recv_done))) {
502       return FailedPrecondition(
503           "Nest companion paths do not match for channel %d", channel.id);
504     }
505   }
506   return Status::OK();
507 }
508 
CheckCommunicatingInstruction(HloInstruction * instruction) const509 Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
510     HloInstruction* instruction) const {
511   HloComputation* computation = instruction->parent();
512   const HloModule* module = computation->parent();
513   if (module->entry_computation() == computation ||
514       tracked_instructions_.contains(computation)) {
515     return Status::OK();
516   }
517   return FailedPrecondition("channel is used in disallowed computation");
518 }
519 
DumpCollectedStats() const520 void HloModuleGroupMetadata::DumpCollectedStats() const {
521   std::map<std::pair<int64, int64>, int64> communication_histogram;
522   for (auto& channel : channels_) {
523     auto from_device = GetInstructionDevice(*channel.send);
524     auto to_device = GetInstructionDevice(*channel.recv);
525     LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device
526               << " to_device=" << *to_device << " send=" << channel.send->name()
527               << " send_done=" << channel.send_done->name()
528               << " recv=" << channel.recv->name()
529               << " recv_done=" << channel.recv_done->name();
530     communication_histogram[std::pair<int64, int64>(*from_device,
531                                                     *to_device)] += 1;
532   }
533   for (auto& fromto_count : communication_histogram) {
534     LOG(INFO) << "From " << fromto_count.first.first << " to "
535               << fromto_count.first.second << ": " << fromto_count.second;
536   }
537   for (auto& companion_set : companion_sets_) {
538     LOG(INFO) << "Companion set:";
539     for (HloInstruction* instruction : *companion_set) {
540       LOG(INFO) << "  " << instruction->name();
541     }
542   }
543   for (auto& instruction_comm : tracked_instructions_comms_) {
544     LOG(INFO) << "Communicating instruction " << instruction_comm.first->name();
545     for (HloInstruction* instruction : instruction_comm.second) {
546       auto device = GetInstructionDevice(*instruction);
547       LOG(INFO) << "  " << instruction->name() << " on device " << *device;
548     }
549   }
550 }
551 
552 }  // namespace xla
553