1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/hlo_schedule.h"

#include <algorithm>
#include <queue>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
30 
31 namespace xla {
32 
CreateFromProto(const HloModule * module,const HloScheduleProto & proto)33 /* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
34     const HloModule* module, const HloScheduleProto& proto) {
35   absl::flat_hash_map<int64, const HloComputation*> id_to_computation;
36   for (const HloComputation* computation : module->computations()) {
37     id_to_computation[computation->unique_id()] = computation;
38   }
39 
40   HloSchedule schedule(module);
41   for (const auto& id_sequence : proto.sequences()) {
42     int64 computation_id = id_sequence.first;
43 
44     auto comp_it = id_to_computation.find(computation_id);
45     TF_RET_CHECK(comp_it != id_to_computation.end())
46         << "No computation exists in HLO module with id " << computation_id;
47     const HloComputation* computation = comp_it->second;
48 
49     absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
50     for (HloInstruction* instruction : computation->instructions()) {
51       id_to_instruction[instruction->unique_id()] = instruction;
52     }
53 
54     HloInstructionSequence& sequence =
55         schedule.GetOrCreateSequence(computation);
56     for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
57       auto instr_it = id_to_instruction.find(instruction_id);
58       TF_RET_CHECK(instr_it != id_to_instruction.end())
59           << "No instruction exists in HLO computation " << computation->name()
60           << " with id " << instruction_id;
61       sequence.push_back(instr_it->second);
62     }
63   }
64   TF_RETURN_IF_ERROR(schedule.Verify());
65   return std::move(schedule);
66 }
67 
ToProto() const68 StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
69   TF_RETURN_IF_ERROR(Verify());
70   HloScheduleProto proto;
71   for (const auto& id_sequence : sequences_) {
72     int64 computation_id = id_sequence.first;
73     const HloInstructionSequence& sequence = id_sequence.second;
74     HloScheduleProto::InstructionSequence& proto_sequence =
75         (*proto.mutable_sequences())[computation_id];
76     proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
77     for (const int64 id : sequence.ids()) {
78       proto_sequence.add_instruction_ids(id);
79     }
80   }
81   return std::move(proto);
82 }
83 
set_sequence(const HloComputation * computation,absl::Span<HloInstruction * const> sequence)84 void HloSchedule::set_sequence(const HloComputation* computation,
85                                absl::Span<HloInstruction* const> sequence) {
86   set_sequence(computation, HloInstructionSequence(sequence));
87 }
88 
set_sequence(const HloComputation * computation,HloInstructionSequence sequence)89 void HloSchedule::set_sequence(const HloComputation* computation,
90                                HloInstructionSequence sequence) {
91   CHECK(computation->parent() == module_);
92   sequences_[computation->unique_id()] = std::move(sequence);
93 }
94 
GetOrCreateSequence(const HloComputation * computation)95 HloInstructionSequence& HloSchedule::GetOrCreateSequence(
96     const HloComputation* computation) {
97   auto it = sequences_.find(computation->unique_id());
98   if (it == sequences_.end()) {
99     // No sequence found for computation. Create and return an empty one.
100     CHECK(computation->parent() == module_);
101     return sequences_[computation->unique_id()];
102   } else {
103     return it->second;
104   }
105 }
106 
sequence(const HloComputation * computation) const107 const HloInstructionSequence& HloSchedule::sequence(
108     const HloComputation* computation) const {
109   return sequences_.at(computation->unique_id());
110 }
111 
UpdateComputationSchedule(const HloComputation * computation)112 Status HloSchedule::UpdateComputationSchedule(
113     const HloComputation* computation) {
114   // Map from unique ID to HloInstruction pointer for instructions in the
115   // computation.
116   absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
117   for (HloInstruction* instruction : computation->instructions()) {
118     InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
119   }
120 
121   // Set of all HloInstructions in the schedule.
122   absl::flat_hash_set<int> ids_in_schedule;
123   for (int id : sequences_.at(computation->unique_id()).ids()) {
124     InsertOrDie(&ids_in_schedule, id);
125   }
126 
127   // Map from HloInstruction X to newly added instructions (instruction is in
128   // computation, but not in schedule) which use X. If an instruction is not in
129   // the map, then it has no users which are newly added instructions.
130   absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
131       new_instruction_uses;
132 
133   // For each newly added instruction, this is the count of the instruction's
134   // operands that have not yet been scheduled. When this value reaches zero,
135   // then the instruction may be placed in the schedule.
136   absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
137 
138   // Create a worklist of newly added instructions which are ready to be added
139   // to the schedule. Initialize worklist with those that have zero operands.
140   std::queue<HloInstruction*> worklist;
141 
142   for (HloInstruction* instruction : computation->instructions()) {
143     if (!ids_in_schedule.contains(instruction->unique_id())) {
144       // This is a newly added instruction which is not in the schedule.
145       if (instruction->operands().empty()) {
146         worklist.push(instruction);
147       } else {
148         for (const HloInstruction* operand : instruction->operands()) {
149           new_instruction_uses[operand].push_back(instruction);
150         }
151         unscheduled_operand_count[instruction] = instruction->operand_count();
152       }
153     }
154   }
155 
156   // Update the schedule with the newly added instructions, and remove any
157   // instructions no longer in the graph.
158   HloInstructionSequence new_sequence;
159 
160   // Lambda which schedules all instructions on the worklist.
161   auto schedule_worklist = [&]() {
162     while (!worklist.empty()) {
163       HloInstruction* instruction = worklist.front();
164       worklist.pop();
165       new_sequence.push_back(instruction);
166       std::vector<HloInstruction*>* new_users =
167           tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
168       if (new_users != nullptr) {
169         // This just-scheduled instruction has users which are newly added to
170         // the module. Update the number of unscheduled operands and push the
171         // newly added instruction to the worklist if it is ready to
172         // schedule.
173         for (HloInstruction* new_user : *new_users) {
174           unscheduled_operand_count.at(new_user)--;
175           CHECK_GE(unscheduled_operand_count.at(new_user), 0);
176           if (unscheduled_operand_count.at(new_user) == 0) {
177             worklist.push(new_user);
178           }
179         }
180       }
181     }
182   };
183 
184   schedule_worklist();
185   for (int id : sequences_.at(computation->unique_id()).ids()) {
186     auto it = id_to_instruction.find(id);
187     if (it == id_to_instruction.end()) {
188       // This instruction in the schedule is no longer in the module. Do not add
189       // it to the new schedule.
190       continue;
191     }
192     worklist.push(it->second);
193     schedule_worklist();
194   }
195 
196   set_sequence(computation, std::move(new_sequence));
197   return Status::OK();
198 }
199 
Update()200 Status HloSchedule::Update() {
201   // The schedule must contain a sequence for every non-fusion computation in
202   // the module, but can have sequences for computations which no longer exist
203   // (these are removed).
204   std::vector<HloComputation*> nonfusion_computations =
205       module_->MakeNonfusionComputations();
206   for (const HloComputation* computation : nonfusion_computations) {
207     TF_RET_CHECK(sequences_.contains(computation->unique_id()))
208         << "Computation " << computation->name() << " not in HloSchedule.";
209   }
210   if (sequences_.size() > nonfusion_computations.size()) {
211     // Schedule contains some computations which have been removed from the
212     // HloModule. Remove them from the schedule as well.
213     absl::flat_hash_set<int64> nonfusion_computations_ids;
214     for (const HloComputation* computation : nonfusion_computations) {
215       nonfusion_computations_ids.insert(computation->unique_id());
216     }
217     for (auto it = sequences_.begin(); it != sequences_.end();) {
218       if (!nonfusion_computations_ids.contains(it->first)) {
219         sequences_.erase(it++);
220       } else {
221         ++it;
222       }
223     }
224   }
225   CHECK_EQ(sequences_.size(), nonfusion_computations.size());
226 
227   for (const HloComputation* computation : nonfusion_computations) {
228     TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
229   }
230 
231   TF_RETURN_IF_ERROR(Verify());
232   return Status::OK();
233 }
234 
// Checks that this schedule is consistent with the current state of the
// module.
//
// Returns an error status (produced by TF_RET_CHECK) if any of the following
// does not hold, and OK otherwise:
//   * the schedule has exactly one sequence per non-fusion computation of the
//     module;
//   * no instruction appears more than once within a sequence;
//   * each sequence contains exactly the instructions of its computation;
//   * every instruction is scheduled after each of its operands and each of
//     its control predecessors.
Status HloSchedule::Verify() const {
  VLOG(2) << "VerifySchedule()";
  XLA_VLOG_LINES(2, ToString());

  // Verify schedule contains exactly the same set of non-fusion computations as
  // module currently does.
  std::vector<HloComputation*> nonfusion_computations =
      module_->MakeNonfusionComputations();
  TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
      << "Schedule has " << sequences_.size() << " sequences, but module has "
      << nonfusion_computations.size() << " non-fusion computations";
  for (const HloComputation* computation : nonfusion_computations) {
    TF_RET_CHECK(sequences_.contains(computation->unique_id()))
        << "Computation " << computation->name()
        << " missing from HLO schedule.";
  }

  // For each computation verify the set of instructions is the same and that
  // each dependency and control edge is honored.
  for (const HloComputation* computation : nonfusion_computations) {
    // Position of each scheduled instruction within the sequence. A failed
    // insert means the same instruction occurs twice in the sequence.
    absl::flat_hash_map<const HloInstruction*, int> instruction_position;
    int pos = 0;
    for (const HloInstruction* instruction :
         sequence(computation).instructions()) {
      TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
          << "Instruction " << instruction->name()
          << " appears more than once in the schedule";
      pos++;
    }

    // With duplicates ruled out, equal counts plus membership of every
    // instruction of the computation establish that both sets are identical.
    TF_RET_CHECK(instruction_position.size() ==
                 computation->instruction_count())
        << "Schedule for computation " << computation->name() << " has "
        << instruction_position.size() << " instructions, expected "
        << computation->instruction_count();
    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(instruction_position.contains(instruction))
          << "Instruction " << instruction->name() << " is not in schedule";
    }

    // Ordering constraints: data dependencies first, then control
    // dependencies, both checked by comparing sequence positions.
    for (const HloInstruction* instruction : computation->instructions()) {
      for (const HloInstruction* operand : instruction->operands()) {
        TF_RET_CHECK(instruction_position.at(operand) <
                     instruction_position.at(instruction))
            << "Instruction " << instruction->name()
            << " is not scheduled after its operand " << operand->name();
      }

      for (const HloInstruction* pred : instruction->control_predecessors()) {
        TF_RET_CHECK(instruction_position.at(pred) <
                     instruction_position.at(instruction))
            << "Instruction " << instruction->name()
            << " is not scheduled after its control predecessor "
            << pred->name();
      }
    }
  }

  return Status::OK();
}
295 
296 namespace {
297 
298 // Returns the computation in the given module with the given unique ID. Returns
299 // nullptr if no such computation exists.
IdToComputation(const HloModule * module,int64 id)300 const HloComputation* IdToComputation(const HloModule* module, int64 id) {
301   for (const HloComputation* computation : module->computations()) {
302     if (computation->unique_id() == id) {
303       return computation;
304     }
305   }
306   return nullptr;
307 }
308 
309 }  // namespace
310 
ToString() const311 string HloSchedule::ToString() const {
312   std::vector<string> pieces;
313 
314   pieces.push_back("HloSchedule");
315   for (const auto& id_sequence : sequences_) {
316     const HloComputation* computation =
317         IdToComputation(module_, id_sequence.first);
318     if (computation == nullptr) {
319       // The computation is not in the module and may have been deleted so it is
320       // not safe to dereference any HLO pointers. Just use the HLO unique ids
321       // stored in this object.
322       pieces.push_back(
323           absl::StrFormat("computation with id %d (no longer in HLO module):",
324                           id_sequence.first));
325       for (int id : id_sequence.second.ids()) {
326         pieces.push_back(absl::StrCat("  ", id));
327       }
328     } else {
329       pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
330       for (const HloInstruction* instruction :
331            id_sequence.second.instructions()) {
332         pieces.push_back(absl::StrCat("  ", instruction->name()));
333       }
334     }
335   }
336   return absl::StrJoin(pieces, "\n");
337 }
338 
operator <<(std::ostream & out,const HloSchedule & schedule)339 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
340   out << schedule.ToString();
341   return out;
342 }
343 
344 }  // namespace xla
345