1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
17
18 #include <deque>
19
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/call_graph.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace xla {
37 namespace {
38
// Queue of instructions whose liveness still has to be propagated. Entries
// are appended at the back and consumed from the front.
using Worklist = std::deque<const HloInstruction*>;
// Set of the instructions currently queued on a Worklist; consulted before
// enqueueing so that an instruction is never pending twice.
using Workset = absl::flat_hash_set<const HloInstruction*>;
41
AddToWorklist(const HloInstruction * instruction,Worklist * worklist,Workset * workset)42 void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
43 Workset* workset) {
44 if (!workset->contains(instruction)) {
45 worklist->push_back(instruction);
46 workset->insert(instruction);
47 VLOG(3) << "ADD instruction: " << instruction->name();
48 }
49 }
50
51 using VisitorFunction = std::function<void(const ShapeIndex& /*index*/)>;
52
ForEachLiveIndex(const ShapeTree<bool> & index_tree,const VisitorFunction & func)53 void ForEachLiveIndex(const ShapeTree<bool>& index_tree,
54 const VisitorFunction& func) {
55 index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) {
56 if (live) {
57 func(shape_index);
58 }
59 });
60 }
61
62 // Marks 'instruction' output live at 'shape_index'.
63 // Adds to 'worklist' iff:
64 // *) 'instruction' is not already on worklist.
65 // *) 'shape_index' has not yet been visited.
MarkLiveAtIndex(const HloInstruction * instruction,const ShapeIndex & shape_index,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)66 void MarkLiveAtIndex(const HloInstruction* instruction,
67 const ShapeIndex& shape_index,
68 HloLivenessAnalysis::HloIndexMap* live_index_map,
69 Worklist* worklist, Workset* workset) {
70 auto it = live_index_map->find(instruction);
71 if (it == live_index_map->end()) {
72 auto it_added = live_index_map->emplace(
73 std::piecewise_construct, std::forward_as_tuple(instruction),
74 std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
75 it = it_added.first;
76 }
77 if (it->second.element(shape_index) == false) {
78 AddToWorklist(instruction, worklist, workset);
79 *it->second.mutable_element(shape_index) = true;
80 VLOG(3) << "MARK instruction: " << instruction->name()
81 << " shape_index: " << shape_index.ToString();
82 }
83 }
84
85 // Marks 'instruction' live at all shape indices in its output.
MarkLiveAtAllIndices(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)86 void MarkLiveAtAllIndices(const HloInstruction* instruction,
87 HloLivenessAnalysis::HloIndexMap* live_index_map,
88 Worklist* worklist, Workset* workset) {
89 bool add_to_worklist = false;
90 auto it = live_index_map->find(instruction);
91 if (it == live_index_map->end()) {
92 live_index_map->emplace(
93 std::piecewise_construct, std::forward_as_tuple(instruction),
94 std::forward_as_tuple(instruction->shape(), /*init_value=*/true));
95 add_to_worklist = true;
96 } else {
97 ShapeUtil::ForEachSubshape(
98 instruction->shape(),
99 [&](const Shape& sub_shape, const ShapeIndex& shape_index) {
100 if (it->second.element(shape_index) == false) {
101 add_to_worklist = true;
102 *it->second.mutable_element(shape_index) = true;
103 VLOG(3) << "MARK instruction: " << instruction->name()
104 << " shape_index: " << shape_index.ToString();
105 }
106 });
107 }
108 if (add_to_worklist) {
109 AddToWorklist(instruction, worklist, workset);
110 }
111 }
112
113 // Propagates liveness through Tuple instructions.
114 // *) For each tuple operand:
115 // *) For tuple output shape index associated with operand:
116 // *) Propgate live shape indices to tuple operand at the associated
117 // shape index in the operands output, and add to worklist.
PropagateLivenessThroughTuple(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)118 void PropagateLivenessThroughTuple(
119 const HloInstruction* instruction,
120 HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
121 Workset* workset) {
122 CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
123 for (int64 operand_index = 0; operand_index < instruction->operand_count();
124 ++operand_index) {
125 const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
126 ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
127 if (shape_index.empty() || shape_index[0] != operand_index) {
128 return;
129 }
130 // Mark top-level index of operand at 'operand_index'.
131 MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map,
132 worklist, workset);
133 // Mark sub-shape index of operand at 'operand_index'.
134 ShapeIndex operand_shape_index;
135 for (int i = 1; i < shape_index.size(); ++i) {
136 operand_shape_index.push_back(shape_index[i]);
137 }
138 MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index,
139 live_index_map, worklist, workset);
140 });
141 }
142 }
143
144 // Propagates liveness through GetTupleElement instructions.
145 // *) For each live index in GetTupleElement output, mark output of GTE operand
146 // at associated shape index in its output, and add to worklist.
PropagateLivenessThroughGTE(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)147 void PropagateLivenessThroughGTE(
148 const HloInstruction* instruction,
149 HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
150 Workset* workset) {
151 CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
152 // Mark operand top-level index.
153 MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
154 workset);
155 const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
156 // Propagate live shape indices along GTE -> Tuple edge.
157 ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
158 ShapeIndex operand_shape_index(shape_index);
159 operand_shape_index.push_front(instruction->tuple_index());
160 MarkLiveAtIndex(instruction->operand(0), operand_shape_index,
161 live_index_map, worklist, workset);
162 });
163 }
164
165 // Propagates liveness through While instructions.
166 // *) For each live index in While output, mark shape index of while.body.root
167 // and while.operand (adding each to worklist).
168 // *) Mark while.cond.root and add to worklist.
PropagateLivenessThroughWhile(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)169 void PropagateLivenessThroughWhile(
170 const HloInstruction* instruction,
171 HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
172 Workset* workset) {
173 CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
174 const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
175
176 ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
177 // Propagate liveness to while body computation root instruction.
178 MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index,
179 live_index_map, worklist, workset);
180 // Propagate liveness to tuple-shaped operand.
181 MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map,
182 worklist, workset);
183 });
184
185 // Propagate liveness to while condition computation root instruction.
186 MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {},
187 live_index_map, worklist, workset);
188 }
189
190 // Propagates liveness out of Parameter instructions to callers and aliasing
191 // positions. This can occur if liveness propagates to a parameter in the
192 // while.condition computation, requiring liveness to propagate out to caller
193 // callsite while (and while.body.root).
PropagateLivenessToParameterCallers(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset,CallGraph * call_graph)194 void PropagateLivenessToParameterCallers(
195 const HloInstruction* instruction,
196 HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
197 Workset* workset, CallGraph* call_graph) {
198 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
199 const CallGraphNode& call_graph_node =
200 call_graph->GetNode(instruction->parent());
201 if (call_graph_node.context() == CallContext::kSequential) {
202 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
203 if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
204 auto* xla_while = callsite.instruction();
205 const ShapeTree<bool>& index_tree =
206 FindOrDie(*live_index_map, instruction);
207 ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
208 // Propagate liveness to while result{shape_index}
209 MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
210 workset);
211 // Propagate liveness to while body root{shape_index}.
212 MarkLiveAtIndex(xla_while->while_body()->root_instruction(),
213 shape_index, live_index_map, worklist, workset);
214 // Propagate liveness to operand(0){shape_index}.
215 MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map,
216 worklist, workset);
217 });
218 }
219 }
220 }
221 }
222
223 // Makes sure that if a live instruction is within a computation used in control
224 // flow operations, we mark live even other related instructions.
PropagateLivenessThroughControlFlow(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset,CallGraph * call_graph)225 void PropagateLivenessThroughControlFlow(
226 const HloInstruction* instruction,
227 HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
228 Workset* workset, CallGraph* call_graph) {
229 const CallGraphNode& call_graph_node =
230 call_graph->GetNode(instruction->parent());
231 if (call_graph_node.context() == CallContext::kSequential) {
232 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
233 HloInstruction* caller = callsite.instruction();
234 if (caller->opcode() == HloOpcode::kWhile) {
235 // If a live instruction is within the %while body or condition
236 // computation, mark the predicate value returned by the condition
237 // computation live as well.
238 MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
239 live_index_map, worklist, workset);
240 } else if (caller->opcode() == HloOpcode::kConditional) {
241 // If a live instruction is within the true or false branches of a
242 // conditional, we mark the predicate operand live as well.
243 MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
244 workset);
245 }
246 }
247 }
248 }
249
250 } // namespace
251
// Stores 'module' in module_ and builds the call graph for it. The analysis
// itself is not run here; that is done by RunAnalysis().
HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
    : module_(module), call_graph_(CallGraph::Build(&module)) {}
254
255 // Runs liveness analysis on 'module_'.
256 // Initializes worklist with entry root instruction (and any instruction with
257 // side-effects), marking all of their output shape indices live.
258 // Visits elements on worklist, propagating liveness from an instructions
259 // live output shape indices to its called computations and operands.
RunAnalysis()260 void HloLivenessAnalysis::RunAnalysis() {
261 Worklist worklist;
262 Workset workset;
263 // Add entry compuation root instruction.
264 MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(),
265 &live_index_map_, &worklist, &workset);
266 for (auto* computation : module_.computations()) {
267 for (auto* instruction : computation->instructions()) {
268 if (instruction->HasSideEffectNoRecurse()) {
269 // Add instructions with side effects.
270 MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist,
271 &workset);
272 }
273 }
274 }
275
276 while (!worklist.empty()) {
277 const HloInstruction* instruction = worklist.front();
278 worklist.pop_front();
279 workset.erase(workset.find(instruction));
280 VLOG(1) << "VISIT instruction: " << instruction->name();
281
282 if (instruction->opcode() == HloOpcode::kTuple) {
283 PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist,
284 &workset);
285 } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
286 PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
287 &workset);
288 } else if (instruction->opcode() == HloOpcode::kWhile) {
289 PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
290 &workset);
291 } else if (instruction->opcode() == HloOpcode::kParameter) {
292 PropagateLivenessToParameterCallers(instruction, &live_index_map_,
293 &worklist, &workset,
294 call_graph_.get());
295 } else {
296 // Propagate liveness to called computations.
297 for (auto* called_computation : instruction->called_computations()) {
298 MarkLiveAtAllIndices(called_computation->root_instruction(),
299 &live_index_map_, &worklist, &workset);
300 }
301 // Propagate liveness to operands.
302 for (HloInstruction* operand : instruction->operands()) {
303 MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
304 }
305 }
306 PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
307 &worklist, &workset, call_graph_.get());
308 }
309 }
310
IsLive(const HloInstruction * instruction,const ShapeIndex & shape_index) const311 bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction,
312 const ShapeIndex& shape_index) const {
313 if (ContainsKey(live_index_map_, instruction)) {
314 return FindOrDie(live_index_map_, instruction).element(shape_index);
315 }
316 return false;
317 }
318
319 /* static */
Run(const HloModule & module)320 StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
321 const HloModule& module) {
322 VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
323 XLA_VLOG_LINES(2, module.ToString());
324
325 auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module));
326
327 liveness_analysis->RunAnalysis();
328
329 return std::move(liveness_analysis);
330 }
331
332 } // namespace xla
333