1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
17 
18 #include <deque>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/call_graph.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
37 namespace {
38 
39 using Worklist = std::deque<const HloInstruction*>;
40 using Workset = absl::flat_hash_set<const HloInstruction*>;
41 
AddToWorklist(const HloInstruction * instruction,Worklist * worklist,Workset * workset)42 void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
43                    Workset* workset) {
44   if (!workset->contains(instruction)) {
45     worklist->push_back(instruction);
46     workset->insert(instruction);
47     VLOG(3) << "ADD instruction: " << instruction->name();
48   }
49 }
50 
51 using VisitorFunction = std::function<void(const ShapeIndex& /*index*/)>;
52 
ForEachLiveIndex(const ShapeTree<bool> & index_tree,const VisitorFunction & func)53 void ForEachLiveIndex(const ShapeTree<bool>& index_tree,
54                       const VisitorFunction& func) {
55   index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) {
56     if (live) {
57       func(shape_index);
58     }
59   });
60 }
61 
62 // Marks 'instruction' output live at 'shape_index'.
63 // Adds to 'worklist' iff:
64 // *) 'instruction' is not already on worklist.
65 // *) 'shape_index' has not yet been visited.
MarkLiveAtIndex(const HloInstruction * instruction,const ShapeIndex & shape_index,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)66 void MarkLiveAtIndex(const HloInstruction* instruction,
67                      const ShapeIndex& shape_index,
68                      HloLivenessAnalysis::HloIndexMap* live_index_map,
69                      Worklist* worklist, Workset* workset) {
70   auto it = live_index_map->find(instruction);
71   if (it == live_index_map->end()) {
72     auto it_added = live_index_map->emplace(
73         std::piecewise_construct, std::forward_as_tuple(instruction),
74         std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
75     it = it_added.first;
76   }
77   if (it->second.element(shape_index) == false) {
78     AddToWorklist(instruction, worklist, workset);
79     *it->second.mutable_element(shape_index) = true;
80     VLOG(3) << "MARK instruction: " << instruction->name()
81             << " shape_index: " << shape_index.ToString();
82   }
83 }
84 
85 // Marks 'instruction' live at all shape indices in its output.
MarkLiveAtAllIndices(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)86 void MarkLiveAtAllIndices(const HloInstruction* instruction,
87                           HloLivenessAnalysis::HloIndexMap* live_index_map,
88                           Worklist* worklist, Workset* workset) {
89   bool add_to_worklist = false;
90   auto it = live_index_map->find(instruction);
91   if (it == live_index_map->end()) {
92     live_index_map->emplace(
93         std::piecewise_construct, std::forward_as_tuple(instruction),
94         std::forward_as_tuple(instruction->shape(), /*init_value=*/true));
95     add_to_worklist = true;
96   } else {
97     ShapeUtil::ForEachSubshape(
98         instruction->shape(),
99         [&](const Shape& sub_shape, const ShapeIndex& shape_index) {
100           if (it->second.element(shape_index) == false) {
101             add_to_worklist = true;
102             *it->second.mutable_element(shape_index) = true;
103             VLOG(3) << "MARK instruction: " << instruction->name()
104                     << " shape_index: " << shape_index.ToString();
105           }
106         });
107   }
108   if (add_to_worklist) {
109     AddToWorklist(instruction, worklist, workset);
110   }
111 }
112 
113 // Propagates liveness through Tuple instructions.
114 // *) For each tuple operand:
115 //   *) For tuple output shape index associated with operand:
116 //     *) Propgate live shape indices to tuple operand at the associated
117 //        shape index in the operands output, and add to worklist.
PropagateLivenessThroughTuple(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)118 void PropagateLivenessThroughTuple(
119     const HloInstruction* instruction,
120     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
121     Workset* workset) {
122   CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
123   for (int64 operand_index = 0; operand_index < instruction->operand_count();
124        ++operand_index) {
125     const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
126     ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
127       if (shape_index.empty() || shape_index[0] != operand_index) {
128         return;
129       }
130       // Mark top-level index of operand at 'operand_index'.
131       MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map,
132                       worklist, workset);
133       // Mark sub-shape index of operand at 'operand_index'.
134       ShapeIndex operand_shape_index;
135       for (int i = 1; i < shape_index.size(); ++i) {
136         operand_shape_index.push_back(shape_index[i]);
137       }
138       MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index,
139                       live_index_map, worklist, workset);
140     });
141   }
142 }
143 
144 // Propagates liveness through GetTupleElement instructions.
145 // *) For each live index in GetTupleElement output, mark output of GTE operand
146 //    at associated shape index in its output, and add to worklist.
PropagateLivenessThroughGTE(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)147 void PropagateLivenessThroughGTE(
148     const HloInstruction* instruction,
149     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
150     Workset* workset) {
151   CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
152   // Mark operand top-level index.
153   MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
154                   workset);
155   const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
156   // Propagate live shape indices along GTE -> Tuple edge.
157   ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
158     ShapeIndex operand_shape_index(shape_index);
159     operand_shape_index.push_front(instruction->tuple_index());
160     MarkLiveAtIndex(instruction->operand(0), operand_shape_index,
161                     live_index_map, worklist, workset);
162   });
163 }
164 
165 // Propagates liveness through While instructions.
166 // *) For each live index in While output, mark shape index of while.body.root
167 //    and while.operand (adding each to worklist).
168 // *) Mark while.cond.root and add to worklist.
PropagateLivenessThroughWhile(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset)169 void PropagateLivenessThroughWhile(
170     const HloInstruction* instruction,
171     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
172     Workset* workset) {
173   CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
174   const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
175 
176   ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
177     // Propagate liveness to while body computation root instruction.
178     MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index,
179                     live_index_map, worklist, workset);
180     // Propagate liveness to tuple-shaped operand.
181     MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map,
182                     worklist, workset);
183   });
184 
185   // Propagate liveness to while condition computation root instruction.
186   MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {},
187                   live_index_map, worklist, workset);
188 }
189 
190 // Propagates liveness out of Parameter instructions to callers and aliasing
191 // positions. This can occur if liveness propagates to a parameter in the
192 // while.condition computation, requiring liveness to propagate out to caller
193 // callsite while (and while.body.root).
PropagateLivenessToParameterCallers(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset,CallGraph * call_graph)194 void PropagateLivenessToParameterCallers(
195     const HloInstruction* instruction,
196     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
197     Workset* workset, CallGraph* call_graph) {
198   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
199   const CallGraphNode& call_graph_node =
200       call_graph->GetNode(instruction->parent());
201   if (call_graph_node.context() == CallContext::kSequential) {
202     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
203       if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
204         auto* xla_while = callsite.instruction();
205         const ShapeTree<bool>& index_tree =
206             FindOrDie(*live_index_map, instruction);
207         ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
208           // Propagate liveness to while result{shape_index}
209           MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
210                           workset);
211           // Propagate liveness to while body root{shape_index}.
212           MarkLiveAtIndex(xla_while->while_body()->root_instruction(),
213                           shape_index, live_index_map, worklist, workset);
214           // Propagate liveness to operand(0){shape_index}.
215           MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map,
216                           worklist, workset);
217         });
218       }
219     }
220   }
221 }
222 
223 // Makes sure that if a live instruction is within a computation used in control
224 // flow operations, we mark live even other related instructions.
PropagateLivenessThroughControlFlow(const HloInstruction * instruction,HloLivenessAnalysis::HloIndexMap * live_index_map,Worklist * worklist,Workset * workset,CallGraph * call_graph)225 void PropagateLivenessThroughControlFlow(
226     const HloInstruction* instruction,
227     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
228     Workset* workset, CallGraph* call_graph) {
229   const CallGraphNode& call_graph_node =
230       call_graph->GetNode(instruction->parent());
231   if (call_graph_node.context() == CallContext::kSequential) {
232     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
233       HloInstruction* caller = callsite.instruction();
234       if (caller->opcode() == HloOpcode::kWhile) {
235         // If a live instruction is within the %while body or condition
236         // computation, mark the predicate value returned by the condition
237         // computation live as well.
238         MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
239                         live_index_map, worklist, workset);
240       } else if (caller->opcode() == HloOpcode::kConditional) {
241         // If a live instruction is within the true or false branches of a
242         // conditional, we mark the predicate operand live as well.
243         MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
244                         workset);
245       }
246     }
247   }
248 }
249 
250 }  // namespace
251 
HloLivenessAnalysis(const HloModule & module)252 HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
253     : module_(module), call_graph_(CallGraph::Build(&module)) {}
254 
255 // Runs liveness analysis on 'module_'.
256 // Initializes worklist with entry root instruction (and any instruction with
257 // side-effects), marking all of their output shape indices live.
258 // Visits elements on worklist, propagating liveness from an instructions
259 // live output shape indices to its called computations and operands.
RunAnalysis()260 void HloLivenessAnalysis::RunAnalysis() {
261   Worklist worklist;
262   Workset workset;
263   // Add entry compuation root instruction.
264   MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(),
265                        &live_index_map_, &worklist, &workset);
266   for (auto* computation : module_.computations()) {
267     for (auto* instruction : computation->instructions()) {
268       if (instruction->HasSideEffectNoRecurse()) {
269         // Add instructions with side effects.
270         MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist,
271                              &workset);
272       }
273     }
274   }
275 
276   while (!worklist.empty()) {
277     const HloInstruction* instruction = worklist.front();
278     worklist.pop_front();
279     workset.erase(workset.find(instruction));
280     VLOG(1) << "VISIT instruction: " << instruction->name();
281 
282     if (instruction->opcode() == HloOpcode::kTuple) {
283       PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist,
284                                     &workset);
285     } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
286       PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
287                                   &workset);
288     } else if (instruction->opcode() == HloOpcode::kWhile) {
289       PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
290                                     &workset);
291     } else if (instruction->opcode() == HloOpcode::kParameter) {
292       PropagateLivenessToParameterCallers(instruction, &live_index_map_,
293                                           &worklist, &workset,
294                                           call_graph_.get());
295     } else {
296       // Propagate liveness to called computations.
297       for (auto* called_computation : instruction->called_computations()) {
298         MarkLiveAtAllIndices(called_computation->root_instruction(),
299                              &live_index_map_, &worklist, &workset);
300       }
301       // Propagate liveness to operands.
302       for (HloInstruction* operand : instruction->operands()) {
303         MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
304       }
305     }
306     PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
307                                         &worklist, &workset, call_graph_.get());
308   }
309 }
310 
IsLive(const HloInstruction * instruction,const ShapeIndex & shape_index) const311 bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction,
312                                  const ShapeIndex& shape_index) const {
313   if (ContainsKey(live_index_map_, instruction)) {
314     return FindOrDie(live_index_map_, instruction).element(shape_index);
315   }
316   return false;
317 }
318 
319 /* static */
Run(const HloModule & module)320 StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
321     const HloModule& module) {
322   VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
323   XLA_VLOG_LINES(2, module.ToString());
324 
325   auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module));
326 
327   liveness_analysis->RunAnalysis();
328 
329   return std::move(liveness_analysis);
330 }
331 
332 }  // namespace xla
333