1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
18 
#include <algorithm>
#include <ostream>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
26 
27 namespace xla {
28 
29 class HloModule;
30 
31 // Class representing a sequence of HLO instructions such as the sequential
32 // execution order of an HLO computation.
33 class HloInstructionSequence {
34  public:
35   HloInstructionSequence() = default;
HloInstructionSequence(absl::Span<HloInstruction * const> instructions)36   explicit HloInstructionSequence(
37       absl::Span<HloInstruction* const> instructions) {
38     for (HloInstruction* instruction : instructions) {
39       push_back(instruction);
40     }
41   }
42 
43   // Adds the instruction to the end of the sequence.
push_back(HloInstruction * instruction)44   void push_back(HloInstruction* instruction) {
45     instruction_sequence_.push_back(instruction);
46     id_sequence_.push_back(instruction->unique_id());
47   }
48 
49   // Removes the instruction from the sequence.
remove_instruction(HloInstruction * instruction)50   void remove_instruction(HloInstruction* instruction) {
51     auto instruction_it = std::find(instruction_sequence_.begin(),
52                                     instruction_sequence_.end(), instruction);
53     if (instruction_it != instruction_sequence_.end()) {
54       auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(),
55                              instruction->unique_id());
56       instruction_sequence_.erase(instruction_it);
57       id_sequence_.erase(id_it);
58     }
59   }
60 
61   // Replaces the old instruction with the new instruction in the sequence.
replace_instruction(HloInstruction * old_instruction,HloInstruction * new_instruction)62   void replace_instruction(HloInstruction* old_instruction,
63                            HloInstruction* new_instruction) {
64     auto instruction_it =
65         std::find(instruction_sequence_.begin(), instruction_sequence_.end(),
66                   old_instruction);
67     auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(),
68                            old_instruction->unique_id());
69     CHECK(instruction_it != instruction_sequence_.end())
70         << "Do not find instruction id " << old_instruction->unique_id();
71     CHECK(id_it != id_sequence_.end());
72     *instruction_it = new_instruction;
73     *id_it = new_instruction->unique_id();
74   }
75 
76   // Clears the sequence of all instructions.
clear()77   void clear() {
78     instruction_sequence_.clear();
79     id_sequence_.clear();
80   }
81 
size()82   int64 size() const { return instruction_sequence_.size(); }
83 
84   // Returns the sequence of HLO instructions.
instructions()85   const std::vector<HloInstruction*>& instructions() const {
86     return instruction_sequence_;
87   }
88 
89   // Returns the unique IDs of the instructions in the sequence (in order).
ids()90   const std::vector<int>& ids() const { return id_sequence_; }
91 
92  private:
93   // The sequence as HloInstructions.
94   std::vector<HloInstruction*> instruction_sequence_;
95 
96   // The sequence of HLO instructions, represented by their unique IDs. The
97   // sequence is stored as both HloInstructions and unique IDs because the
98   // sequence may be referenced after transformations to the HLO graph and HLO
99   // pointers can be invalidated or recycled in this process (see
100   // HloSchedule::Update).
101   std::vector<int> id_sequence_;
102 };
103 
104 // A class representing a sequential schedule of instructions for an HLO
105 // module. A complete HLO schedule contains an instruction sequence for every
106 // non-fusion computation in the HLO module.
107 class HloSchedule {
108  public:
HloSchedule(const HloModule * module)109   explicit HloSchedule(const HloModule* module) : module_(module) {}
110 
111   // (De)Serialize an HloSchedule to/from a HloScheduleProto.
112   static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
113                                                const HloScheduleProto& proto);
114   StatusOr<HloScheduleProto> ToProto() const;
115 
116   // Returns a reference to the sequence for the given computation.
117   const HloInstructionSequence& sequence(
118       const HloComputation* computation) const;
119 
120   // Returns the sequence for the given computation. An empty sequence is
121   // created if none exists for the computation.
122   HloInstructionSequence& GetOrCreateSequence(
123       const HloComputation* computation);
124 
125   // Sets the sequence for the given computation to the given sequence.
126   void set_sequence(const HloComputation* computation,
127                     absl::Span<HloInstruction* const> sequence);
128   void set_sequence(const HloComputation* computation,
129                     HloInstructionSequence sequence);
130 
131   // Returns a map from HloComputation unique ID to instruction sequence. The
132   // map contains all sequences in the schedule.
sequences()133   const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const {
134     return sequences_;
135   }
136 
137   // Returns true if the schedule has a sequence for the given computation.
is_computation_scheduled(const HloComputation * computation)138   bool is_computation_scheduled(const HloComputation* computation) const {
139     return sequences_.contains(computation->unique_id());
140   }
141 
142   // Removes the computation from the sequences.
remove_computation(const HloComputation * computation)143   void remove_computation(const HloComputation* computation) {
144     auto it = sequences_.find(computation->unique_id());
145     CHECK(it != sequences_.end());
146     sequences_.erase(it);
147   }
148 
149   // Removes the instruction from the computation's sequence.
remove_instruction(const HloComputation * computation,HloInstruction * instruction)150   void remove_instruction(const HloComputation* computation,
151                           HloInstruction* instruction) {
152     sequences_[computation->unique_id()].remove_instruction(instruction);
153   }
154 
155   // Replaces the old instruction with the new instruction in the computation's
156   // sequence.
replace_instruction(const HloComputation * computation,HloInstruction * old_instruction,HloInstruction * new_instruction)157   void replace_instruction(const HloComputation* computation,
158                            HloInstruction* old_instruction,
159                            HloInstruction* new_instruction) {
160     sequences_[computation->unique_id()].replace_instruction(old_instruction,
161                                                              new_instruction);
162   }
163 
164   // Updates the schedule such that it is (again) a valid schedule for the
165   // module. This is used to update a schedule after the HLO module has been
166   // transformed in some way. In general, the only transformations to the module
167   // for which a schedule can be updated is the addition or removal of
168   // instructions and removal of computations. Updating the schedule after new
169   // dependencies between existing instructions in the module is not supported
170   // and may result in an error status returned.
171   //
172   // Instructions in the module which also exist in the given schedule will
173   // remain in the same order in the updated schedule. Instructions which exist
174   // in the module but not in the given schedule will be placed as early as
175   // possible in the updated schedule.
176   Status Update();
177 
178   // Verifies that the given schedule is valid for the given module.
179   // Specifically, the schedule contains exactly the instructions in the
180   // non-fusion computations in the module and every dependency in the module is
181   // satisfied in the schedule.
182   Status Verify() const;
183 
184   string ToString() const;
185 
empty()186   bool empty() const { return sequences_.empty(); }
187 
module()188   const HloModule* module() const { return module_; }
189 
190  private:
191   // Updates the instruction sequence for the given computation.
192   Status UpdateComputationSchedule(const HloComputation* computation);
193 
194   const HloModule* module_;
195 
196   // A map from computation unique ID to instruction sequence. Unique IDs are
197   // used rather than HloComputation pointers because HLO pointers are not
198   // unique across HLO transformations because pointers may be recycled.
199   absl::flat_hash_map<int64, HloInstructionSequence> sequences_;
200 };
201 
// Writes a textual form of `schedule` to `out` and returns `out`.
// NOTE(review): presumably the same text as HloSchedule::ToString() — confirm
// against the definition in the .cc file.
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
203 
204 }  // namespace xla
205 
206 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
207