1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
17
18 #include <queue>
19 #include <vector>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30
31 namespace xla {
32
CreateFromProto(const HloModule * module,const HloScheduleProto & proto)33 /* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
34 const HloModule* module, const HloScheduleProto& proto) {
35 absl::flat_hash_map<int64, const HloComputation*> id_to_computation;
36 for (const HloComputation* computation : module->computations()) {
37 id_to_computation[computation->unique_id()] = computation;
38 }
39
40 HloSchedule schedule(module);
41 for (const auto& id_sequence : proto.sequences()) {
42 int64 computation_id = id_sequence.first;
43
44 auto comp_it = id_to_computation.find(computation_id);
45 TF_RET_CHECK(comp_it != id_to_computation.end())
46 << "No computation exists in HLO module with id " << computation_id;
47 const HloComputation* computation = comp_it->second;
48
49 absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
50 for (HloInstruction* instruction : computation->instructions()) {
51 id_to_instruction[instruction->unique_id()] = instruction;
52 }
53
54 HloInstructionSequence& sequence =
55 schedule.GetOrCreateSequence(computation);
56 for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
57 auto instr_it = id_to_instruction.find(instruction_id);
58 TF_RET_CHECK(instr_it != id_to_instruction.end())
59 << "No instruction exists in HLO computation " << computation->name()
60 << " with id " << instruction_id;
61 sequence.push_back(instr_it->second);
62 }
63 }
64 TF_RETURN_IF_ERROR(schedule.Verify());
65 return std::move(schedule);
66 }
67
ToProto() const68 StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
69 TF_RETURN_IF_ERROR(Verify());
70 HloScheduleProto proto;
71 for (const auto& id_sequence : sequences_) {
72 int64 computation_id = id_sequence.first;
73 const HloInstructionSequence& sequence = id_sequence.second;
74 HloScheduleProto::InstructionSequence& proto_sequence =
75 (*proto.mutable_sequences())[computation_id];
76 proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
77 for (const int64 id : sequence.ids()) {
78 proto_sequence.add_instruction_ids(id);
79 }
80 }
81 return std::move(proto);
82 }
83
set_sequence(const HloComputation * computation,absl::Span<HloInstruction * const> sequence)84 void HloSchedule::set_sequence(const HloComputation* computation,
85 absl::Span<HloInstruction* const> sequence) {
86 set_sequence(computation, HloInstructionSequence(sequence));
87 }
88
set_sequence(const HloComputation * computation,HloInstructionSequence sequence)89 void HloSchedule::set_sequence(const HloComputation* computation,
90 HloInstructionSequence sequence) {
91 CHECK(computation->parent() == module_);
92 sequences_[computation->unique_id()] = std::move(sequence);
93 }
94
GetOrCreateSequence(const HloComputation * computation)95 HloInstructionSequence& HloSchedule::GetOrCreateSequence(
96 const HloComputation* computation) {
97 auto it = sequences_.find(computation->unique_id());
98 if (it == sequences_.end()) {
99 // No sequence found for computation. Create and return an empty one.
100 CHECK(computation->parent() == module_);
101 return sequences_[computation->unique_id()];
102 } else {
103 return it->second;
104 }
105 }
106
sequence(const HloComputation * computation) const107 const HloInstructionSequence& HloSchedule::sequence(
108 const HloComputation* computation) const {
109 return sequences_.at(computation->unique_id());
110 }
111
UpdateComputationSchedule(const HloComputation * computation)112 Status HloSchedule::UpdateComputationSchedule(
113 const HloComputation* computation) {
114 // Map from unique ID to HloInstruction pointer for instructions in the
115 // computation.
116 absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
117 for (HloInstruction* instruction : computation->instructions()) {
118 InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
119 }
120
121 // Set of all HloInstructions in the schedule.
122 absl::flat_hash_set<int> ids_in_schedule;
123 for (int id : sequences_.at(computation->unique_id()).ids()) {
124 InsertOrDie(&ids_in_schedule, id);
125 }
126
127 // Map from HloInstruction X to newly added instructions (instruction is in
128 // computation, but not in schedule) which use X. If an instruction is not in
129 // the map, then it has no users which are newly added instructions.
130 absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
131 new_instruction_uses;
132
133 // For each newly added instruction, this is the count of the instruction's
134 // operands that have not yet been scheduled. When this value reaches zero,
135 // then the instruction may be placed in the schedule.
136 absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
137
138 // Create a worklist of newly added instructions which are ready to be added
139 // to the schedule. Initialize worklist with those that have zero operands.
140 std::queue<HloInstruction*> worklist;
141
142 for (HloInstruction* instruction : computation->instructions()) {
143 if (!ids_in_schedule.contains(instruction->unique_id())) {
144 // This is a newly added instruction which is not in the schedule.
145 if (instruction->operands().empty()) {
146 worklist.push(instruction);
147 } else {
148 for (const HloInstruction* operand : instruction->operands()) {
149 new_instruction_uses[operand].push_back(instruction);
150 }
151 unscheduled_operand_count[instruction] = instruction->operand_count();
152 }
153 }
154 }
155
156 // Update the schedule with the newly added instructions, and remove any
157 // instructions no longer in the graph.
158 HloInstructionSequence new_sequence;
159
160 // Lambda which schedules all instructions on the worklist.
161 auto schedule_worklist = [&]() {
162 while (!worklist.empty()) {
163 HloInstruction* instruction = worklist.front();
164 worklist.pop();
165 new_sequence.push_back(instruction);
166 std::vector<HloInstruction*>* new_users =
167 tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
168 if (new_users != nullptr) {
169 // This just-scheduled instruction has users which are newly added to
170 // the module. Update the number of unscheduled operands and push the
171 // newly added instruction to the worklist if it is ready to
172 // schedule.
173 for (HloInstruction* new_user : *new_users) {
174 unscheduled_operand_count.at(new_user)--;
175 CHECK_GE(unscheduled_operand_count.at(new_user), 0);
176 if (unscheduled_operand_count.at(new_user) == 0) {
177 worklist.push(new_user);
178 }
179 }
180 }
181 }
182 };
183
184 schedule_worklist();
185 for (int id : sequences_.at(computation->unique_id()).ids()) {
186 auto it = id_to_instruction.find(id);
187 if (it == id_to_instruction.end()) {
188 // This instruction in the schedule is no longer in the module. Do not add
189 // it to the new schedule.
190 continue;
191 }
192 worklist.push(it->second);
193 schedule_worklist();
194 }
195
196 set_sequence(computation, std::move(new_sequence));
197 return Status::OK();
198 }
199
Update()200 Status HloSchedule::Update() {
201 // The schedule must contain a sequence for every non-fusion computation in
202 // the module, but can have sequences for computations which no longer exist
203 // (these are removed).
204 std::vector<HloComputation*> nonfusion_computations =
205 module_->MakeNonfusionComputations();
206 for (const HloComputation* computation : nonfusion_computations) {
207 TF_RET_CHECK(sequences_.contains(computation->unique_id()))
208 << "Computation " << computation->name() << " not in HloSchedule.";
209 }
210 if (sequences_.size() > nonfusion_computations.size()) {
211 // Schedule contains some computations which have been removed from the
212 // HloModule. Remove them from the schedule as well.
213 absl::flat_hash_set<int64> nonfusion_computations_ids;
214 for (const HloComputation* computation : nonfusion_computations) {
215 nonfusion_computations_ids.insert(computation->unique_id());
216 }
217 for (auto it = sequences_.begin(); it != sequences_.end();) {
218 if (!nonfusion_computations_ids.contains(it->first)) {
219 sequences_.erase(it++);
220 } else {
221 ++it;
222 }
223 }
224 }
225 CHECK_EQ(sequences_.size(), nonfusion_computations.size());
226
227 for (const HloComputation* computation : nonfusion_computations) {
228 TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
229 }
230
231 TF_RETURN_IF_ERROR(Verify());
232 return Status::OK();
233 }
234
Verify() const235 Status HloSchedule::Verify() const {
236 VLOG(2) << "VerifySchedule()";
237 XLA_VLOG_LINES(2, ToString());
238
239 // Verify schedule contains exactly the same set of non-fusion computations as
240 // module currently does.
241 std::vector<HloComputation*> nonfusion_computations =
242 module_->MakeNonfusionComputations();
243 TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
244 << "Schedule has " << sequences_.size() << " sequences, but module has "
245 << nonfusion_computations.size() << " non-fusion computations";
246 for (const HloComputation* computation : nonfusion_computations) {
247 TF_RET_CHECK(sequences_.contains(computation->unique_id()))
248 << "Computation " << computation->name()
249 << " missing from HLO schedule.";
250 }
251
252 // For each computation verify the set of instructions is the same and that
253 // each dependency and control edge is honored.
254 for (const HloComputation* computation : nonfusion_computations) {
255 absl::flat_hash_map<const HloInstruction*, int> instruction_position;
256 int pos = 0;
257 for (const HloInstruction* instruction :
258 sequence(computation).instructions()) {
259 TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
260 << "Instruction " << instruction->name()
261 << " appears more than once in the schedule";
262 pos++;
263 }
264
265 TF_RET_CHECK(instruction_position.size() ==
266 computation->instruction_count())
267 << "Schedule for computation " << computation->name() << " has "
268 << instruction_position.size() << " instructions, expected "
269 << computation->instruction_count();
270 for (const HloInstruction* instruction : computation->instructions()) {
271 TF_RET_CHECK(instruction_position.contains(instruction))
272 << "Instruction " << instruction->name() << " is not in schedule";
273 }
274
275 for (const HloInstruction* instruction : computation->instructions()) {
276 for (const HloInstruction* operand : instruction->operands()) {
277 TF_RET_CHECK(instruction_position.at(operand) <
278 instruction_position.at(instruction))
279 << "Instruction " << instruction->name()
280 << " is not scheduled after its operand " << operand->name();
281 }
282
283 for (const HloInstruction* pred : instruction->control_predecessors()) {
284 TF_RET_CHECK(instruction_position.at(pred) <
285 instruction_position.at(instruction))
286 << "Instruction " << instruction->name()
287 << " is not scheduled after its control predecessor "
288 << pred->name();
289 }
290 }
291 }
292
293 return Status::OK();
294 }
295
296 namespace {
297
298 // Returns the computation in the given module with the given unique ID. Returns
299 // nullptr if no such computation exists.
IdToComputation(const HloModule * module,int64 id)300 const HloComputation* IdToComputation(const HloModule* module, int64 id) {
301 for (const HloComputation* computation : module->computations()) {
302 if (computation->unique_id() == id) {
303 return computation;
304 }
305 }
306 return nullptr;
307 }
308
309 } // namespace
310
ToString() const311 string HloSchedule::ToString() const {
312 std::vector<string> pieces;
313
314 pieces.push_back("HloSchedule");
315 for (const auto& id_sequence : sequences_) {
316 const HloComputation* computation =
317 IdToComputation(module_, id_sequence.first);
318 if (computation == nullptr) {
319 // The computation is not in the module and may have been deleted so it is
320 // not safe to dereference any HLO pointers. Just use the HLO unique ids
321 // stored in this object.
322 pieces.push_back(
323 absl::StrFormat("computation with id %d (no longer in HLO module):",
324 id_sequence.first));
325 for (int id : id_sequence.second.ids()) {
326 pieces.push_back(absl::StrCat(" ", id));
327 }
328 } else {
329 pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
330 for (const HloInstruction* instruction :
331 id_sequence.second.instructions()) {
332 pieces.push_back(absl::StrCat(" ", instruction->name()));
333 }
334 }
335 }
336 return absl::StrJoin(pieces, "\n");
337 }
338
operator <<(std::ostream & out,const HloSchedule & schedule)339 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
340 out << schedule.ToString();
341 return out;
342 }
343
344 } // namespace xla
345