1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
18 
#include <ostream>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
27 
28 namespace xla {
29 
30 class HloModule;
31 
32 // Class representing a sequence of HLO instructions such as the sequential
33 // execution order of an HLO computation.
34 class HloInstructionSequence {
35  public:
36   HloInstructionSequence() = default;
HloInstructionSequence(absl::Span<HloInstruction * const> instructions)37   explicit HloInstructionSequence(
38       absl::Span<HloInstruction* const> instructions) {
39     for (HloInstruction* instruction : instructions) {
40       push_back(instruction);
41     }
42   }
43 
44   // Adds the instruction to the end of the sequence.
push_back(HloInstruction * instruction)45   void push_back(HloInstruction* instruction) {
46     instruction_sequence_.push_back(instruction);
47     id_sequence_.push_back(instruction->unique_id());
48   }
49 
50   // Clears the sequence of all instructions.
clear()51   void clear() {
52     instruction_sequence_.clear();
53     id_sequence_.clear();
54   }
55 
size()56   int64 size() const { return instruction_sequence_.size(); }
57 
58   // Returns the sequence of HLO instructions.
instructions()59   const std::vector<HloInstruction*>& instructions() const {
60     return instruction_sequence_;
61   }
62 
63   // Returns the unique IDs of the instructions in the sequence (in order).
ids()64   const std::vector<int>& ids() const { return id_sequence_; }
65 
66  private:
67   // The sequence as HloInstructions.
68   std::vector<HloInstruction*> instruction_sequence_;
69 
70   // The sequence of HLO instructions, represented by their unique IDs. The
71   // sequence is stored as both HloInstructions and unique IDs because the
72   // sequence may be referenced after transformations to the HLO graph and HLO
73   // pointers can be invalidated or recycled in this process (see
74   // HloSchedule::Update).
75   std::vector<int> id_sequence_;
76 };
77 
78 // A class representing a sequential schedule of instructions for an HLO
79 // module. A complete HLO schedule contains an instruction sequence for every
80 // non-fusion computation in the HLO module.
81 class HloSchedule {
82  public:
HloSchedule(const HloModule * module)83   explicit HloSchedule(const HloModule* module) : module_(module) {}
84 
85   // (De)Serialize an HloSchedule to/from a HloScheduleProto.
86   static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
87                                                const HloScheduleProto& proto);
88   StatusOr<HloScheduleProto> ToProto() const;
89 
90   // Returns a reference to the sequence for the given computation.
91   const HloInstructionSequence& sequence(
92       const HloComputation* computation) const;
93 
94   // Returns the sequence for the given computation. An empty sequence is
95   // created if none exists for the computation.
96   HloInstructionSequence& GetOrCreateSequence(
97       const HloComputation* computation);
98 
99   // Sets the sequence for the given computation to the given sequence.
100   void set_sequence(const HloComputation* computation,
101                     absl::Span<HloInstruction* const> sequence);
102   void set_sequence(const HloComputation* computation,
103                     HloInstructionSequence sequence);
104 
105   // Returns a map from HloComputation unique ID to instruction sequence. The
106   // map contains all sequences in the schedule.
sequences()107   const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const {
108     return sequences_;
109   }
110 
111   // Returns true if the schedule has a sequence for the given computation.
is_computation_scheduled(const HloComputation * computation)112   bool is_computation_scheduled(const HloComputation* computation) const {
113     return sequences_.contains(computation->unique_id());
114   }
115 
116   // Updates the schedule such that it is (again) a valid schedule for the
117   // module. This is used to update a schedule after the HLO module has been
118   // transformed in some way. In general, the only transformations to the module
119   // for which a schedule can be updated is the addition or removal of
120   // instructions and removal of computations. Updating the schedule after new
121   // dependencies between existing instructions in the module is not supported
122   // and may result in an error status returned.
123   //
124   // Instructions in the module which also exist in the given schedule will
125   // remain in the same order in the updated schedule. Instructions which exist
126   // in the module but not in the given schedule will be placed as early as
127   // possible in the updated schedule.
128   Status Update();
129 
130   // Verifies that the given schedule is valid for the given module.
131   // Specifically, the schedule contains exactly the instructions in the
132   // non-fusion computations in the module and every dependency in the module is
133   // satisfied in the schedule.
134   Status Verify() const;
135 
136   string ToString() const;
137 
empty()138   bool empty() const { return sequences_.empty(); }
139 
module()140   const HloModule* module() const { return module_; }
141 
142  private:
143   // Updates the instruction sequence for the given computation.
144   Status UpdateComputationSchedule(const HloComputation* computation);
145 
146   const HloModule* module_;
147 
148   // A map from computation unique ID to instruction sequence. Unique IDs are
149   // used rather than HloComputation pointers because HLO pointers are not
150   // unique across HLO transformations because pointers may be recycled.
151   absl::flat_hash_map<int64, HloInstructionSequence> sequences_;
152 };
153 
154 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
155 
156 }  // namespace xla
157 
158 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
159