1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12 License for the specific language governing permissions and limitations under
13 the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
17 
#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
32 
33 namespace xla {
34 
35 // Class which computes live range of the output buffers of HLOs and their
36 // interference by flattening all computations. The live range is only available
37 // when all global computations (while, if, call, etc) have total order
38 // sequential orders.
39 class HloLiveRange {
40  public:
41   // Constructs a hlo live range object for the given module and computation
42   // assuming the given HLO instruction ordering.
43   static StatusOr<std::unique_ptr<HloLiveRange>> Run(
44       const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
45       const HloComputation* computation, bool module_scoped_analysis = true);
46 
47   // LogicalTime represents the time in a virtual clock. Each instruction has
48   // one monotonically increasing logical time assigned according to the
49   // schedule.
50   using LogicalTime = int64;
51 
52   struct TimeBound {
53     LogicalTime start;
54     LogicalTime end;
55 
56     bool friend operator==(const TimeBound& a, const TimeBound& b) {
57       return a.start == b.start && a.end == b.end;
58     }
59     bool friend operator!=(const TimeBound& a, const TimeBound& b) {
60       return !(a == b);
61     }
62   };
63 
64   std::string ToString() const;
65 
flattened_instruction_sequence()66   const HloInstructionSequence& flattened_instruction_sequence() const {
67     return flattened_instruction_sequence_;
68   }
69 
70   // Returns the map from instruction to the end time of that instruction.
71   const absl::flat_hash_map<const HloInstruction*, LogicalTime>&
instruction_schedule()72   instruction_schedule() const {
73     return instruction_schedule_;
74   }
75 
76   // Returns the map from a hlo value to the definition time of that hlo value.
buffer_live_ranges()77   const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges()
78       const {
79     return buffer_live_ranges_;
80   }
81 
buffer_live_ranges()82   absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() {
83     return buffer_live_ranges_;
84   }
85 
86   // Returns the map from a computation and its time span in the schedule.
87   const absl::flat_hash_map<const HloComputation*, TimeBound>&
computation_span_times()88   computation_span_times() const {
89     return computation_span_times_;
90   }
91 
92   // Returns the time stamp of the end of the program.
schedule_end_time()93   LogicalTime schedule_end_time() const { return schedule_end_time_; }
94 
95   // Returns whether hlo live range is available on this entire module. Hlo live
96   // range is not available if the module is partially ordered.
total_order_scheduled()97   bool total_order_scheduled() const { return total_order_scheduled_; }
98 
99  private:
HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)100   explicit HloLiveRange(const HloSchedule& schedule,
101                         const HloAliasAnalysis& alias_analysis,
102                         bool module_scoped_analysis)
103       : schedule_(schedule),
104         alias_analysis_(alias_analysis),
105         module_scoped_analysis_(module_scoped_analysis) {}
106 
107   // FlattenSchedule walks through the instructions in `computation`, and
108   // recurse into each called computations in module_scoped_analysis mode. As it
109   // walks it also tracks down the ordinal number of each instruction in the
110   // schedule and store it in the `instruction_schedule` and
111   // 'flattened_instruction_sequence`. The end of each computation is tracked in
112   // `computation_end_time`.
113   int64 FlattenSchedule(const HloComputation& computation, int64 start_time);
114 
115   // Based on the flattened schedule, calculate the start and end of each
116   // buffer.
117   void CalculateBufferStartEndMap();
118 
119   // The aliased buffers could have overlapping live ranges.
120   // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer
121   // has disjoint live range while keeping the live range union the same. This
122   // avoid double counting aliased buffer sizes.
123   //
124   // Before(buffer1 and 2 are aliased):
125   //
126   //           +----+          live range of buffer1
127   //   +------------------+    live range of buffer2
128   //
129   // After:
130   //
131   //           +----------+    live range of buffer1
132   //   +------+                live range of buffer2
133   //
134   // Before(buffer1 and 2 are aliased):
135   //
136   //           +----------+    live range of buffer1
137   //   +------------+          live range of buffer2
138   //
139   // After:
140   //
141   //           +----------+    live range of buffer1
142   //   +------+                live range of buffer2
143   //
144   // Before(buffer1 and 2 are aliased):
145   //
146   //           +----------+    live range of buffer1
147   //   +---+                   live range of buffer2
148   //
149   // After(unchanged):
150   //
151   //           +----------+    live range of buffer1
152   //   +---+                   live range of buffer2
153   //
154   // As another example, imagine we have the following code sequence with live
155   // ranges of each while-aliased buffers:
156   //
157   //                     a      p1    p2    e     b
158   // a = ...             +
159   //                     |
160   // {                   |
161   //   p1 = param        |       +
162   //   ROOT true         |       |
163   // }                   |       +
164   // { // body           |
165   //   p2 = param        +             +
166   //   c = p2 + 1                      +
167   //   d = c + 1
168   //   ROOT e = d + 1                       +
169   // }                                      |
170   //                                        |
171   // b = while (a)                          +     +
172   //                                              |
173   // f = b + 1                                    +
174   //
175   // After normalization it becomes:
176   //
177   //                     a      p1    p2    e     b
178   // a = ...             +
179   //                     |
180   // {                   +
181   //   p1 = param                +
182   //   ROOT true                 |
183   // }                           +
184   // { // body
185   //   p2 = param                      +
186   //   c = p2 + 1                      +
187   //   d = c + 1
188   //   ROOT e = d + 1                       +
189   // }                                      |
190   //                                        |
191   // b = while (a)                          +
192   //                                              +
193   // f = b + 1                                    +
194   //
195   // Note there is no overlap of live ranges after normalization.
196   void NormalizeAliasedBuffers();
197 
198   const HloSchedule& schedule_;
199   const HloAliasAnalysis& alias_analysis_;
200   bool module_scoped_analysis_;
201   bool total_order_scheduled_ = true;
202 
203   HloInstructionSequence flattened_instruction_sequence_;
204   absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_;
205   absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_;
206   absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_;
207   LogicalTime schedule_end_time_;
208 };
209 
210 }  // namespace xla
211 
212 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
213