1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/call_graph.h"
17 
18 #include <queue>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace xla {
33 
34 using absl::StrAppendFormat;
35 using absl::StrCat;
36 
CallContextToString(CallContext context)37 string CallContextToString(CallContext context) {
38   switch (context) {
39     case CallContext::kNone:
40       return "kNone";
41     case CallContext::kSequential:
42       return "kSequential";
43     case CallContext::kParallel:
44       return "kParallel";
45     case CallContext::kBoth:
46       return "kBoth";
47   }
48 }
49 
operator <<(std::ostream & out,const CallContext & context)50 std::ostream& operator<<(std::ostream& out, const CallContext& context) {
51   out << CallContextToString(context);
52   return out;
53 }
54 
GetInstructionCallContext(HloOpcode opcode)55 CallContext GetInstructionCallContext(HloOpcode opcode) {
56   switch (opcode) {
57     case HloOpcode::kCall:
58     case HloOpcode::kConditional:
59     case HloOpcode::kWhile:
60       return CallContext::kSequential;
61     case HloOpcode::kAllReduce:
62     case HloOpcode::kMap:
63     case HloOpcode::kReduce:
64     case HloOpcode::kReduceWindow:
65     case HloOpcode::kScatter:
66     case HloOpcode::kSelectAndScatter:
67     case HloOpcode::kSort:
68     case HloOpcode::kFusion:
69       return CallContext::kParallel;
70     default:
71       return CallContext::kNone;
72   }
73 }
74 
ToString() const75 string CallSite::ToString() const {
76   return StrCat(
77       instruction()->name(), " calls in context ",
78       CallContextToString(context()), ": ",
79       absl::StrJoin(called_computations(), ", ",
80                     [](string* out, const HloComputation* computation) {
81                       out->append(computation->name());
82                     }));
83 }
84 
CallGraphNode(HloComputation * computation)85 CallGraphNode::CallGraphNode(HloComputation* computation)
86     : computation_(computation) {}
87 
GetCallSite(const HloInstruction * instruction) const88 const CallSite* CallGraphNode::GetCallSite(
89     const HloInstruction* instruction) const {
90   auto it = callsite_instructions_.find(instruction);
91   if (it == callsite_instructions_.end()) {
92     return nullptr;
93   }
94   return &callsites_[it->second];
95 }
96 
AddCallerCallSite(const CallSite & caller_callsite)97 void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
98   caller_callsites_.push_back(caller_callsite);
99   HloComputation* caller = caller_callsite.instruction()->parent();
100   if (!ContainsKey(caller_set_, caller)) {
101     callers_.push_back(caller);
102     caller_set_.insert(caller);
103   }
104 }
105 
AddCallSiteForInstruction(HloInstruction * instruction)106 void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
107   CHECK_EQ(instruction->parent(), computation());
108   const CallContext context = GetInstructionCallContext(instruction->opcode());
109   if (!instruction->called_computations().empty()) {
110     CHECK(context == CallContext::kSequential ||
111           context == CallContext::kParallel);
112     callsite_instructions_.insert({instruction, callsites_.size()});
113     callsites_.push_back(
114         CallSite(instruction, instruction->called_computations(), context));
115     // Update callee computations to include any new computations called by this
116     // instruction.
117     for (auto* callee : callsites_.back().called_computations()) {
118       if (!ContainsKey(callee_set_, callee)) {
119         callees_.push_back(callee);
120         callee_set_.insert(callee);
121       }
122     }
123   }
124 }
125 
CallGraph(const HloModule * module)126 CallGraph::CallGraph(const HloModule* module) : module_(module) {}
127 
GetNode(const HloComputation * computation) const128 const CallGraphNode& CallGraph::GetNode(
129     const HloComputation* computation) const {
130   auto it = node_indices_.find(computation);
131   CHECK(it != node_indices_.end());
132   return nodes_[it->second];
133 }
134 
GetNode(const HloComputation * computation)135 CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
136   auto it = node_indices_.find(computation);
137   CHECK(it != node_indices_.end());
138   return nodes_[it->second];
139 }
140 
DominatesHelper(const HloComputation * a,const HloComputation * b,absl::flat_hash_set<const HloComputation * > * visited) const141 bool CallGraph::DominatesHelper(
142     const HloComputation* a, const HloComputation* b,
143     absl::flat_hash_set<const HloComputation*>* visited) const {
144   if (a == b || ContainsKey(*visited, b)) {
145     // The call graph is guaranteed to be acyclic so any previously visited node
146     // we encounter was already determined to be dominated.
147     return true;
148   }
149 
150   const CallGraphNode& b_node = GetNode(b);
151   if (b_node.callers().empty()) {
152     // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
153     return false;
154   }
155 
156   // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
157   visited->insert(b);
158   for (const HloComputation* b_caller : b_node.callers()) {
159     if (!DominatesHelper(a, b_caller, visited)) {
160       return false;
161     }
162   }
163   return true;
164 }
165 
Dominates(const HloComputation * a,const HloComputation * b) const166 bool CallGraph::Dominates(const HloComputation* a,
167                           const HloComputation* b) const {
168   absl::flat_hash_set<const HloComputation*> visited;
169   return DominatesHelper(a, b, &visited);
170 }
171 
172 namespace {
173 
174 // Returns the call context of a computation which is called from contexts 'a'
175 // and 'b'.
UnionContexts(CallContext a,CallContext b)176 CallContext UnionContexts(CallContext a, CallContext b) {
177   if (a == CallContext::kNone) {
178     return b;
179   } else if (b == CallContext::kNone) {
180     return a;
181   } else if (a == b) {
182     return a;
183   } else {
184     // Contexts are different and neither is kNone, ie one is kSequential and
185     // the other is kParallel.
186     return CallContext::kBoth;
187   }
188 }
189 
190 }  // namespace
191 
SetCallContexts()192 void CallGraph::SetCallContexts() {
193   std::queue<CallGraphNode*> worklist;
194 
195   // Initialize worklist with all roots of the call graph (computations without
196   // callers).
197   for (const HloComputation* computation : module_->computations()) {
198     CallGraphNode& node = GetNode(computation);
199     if (node.callers().empty()) {
200       node.set_context(CallContext::kSequential);
201       worklist.push(&node);
202     }
203   }
204 
205   while (!worklist.empty()) {
206     CallGraphNode* node = worklist.front();
207     worklist.pop();
208 
209     for (const CallSite& callsite : node->callsites()) {
210       for (const HloComputation* callee : callsite.called_computations()) {
211         CallGraphNode& callee_node = GetNode(callee);
212 
213         // Update context of callee computation based on the callsite and its
214         // current context.
215         CallContext context_to_add;
216         if (callsite.context() == CallContext::kParallel) {
217           context_to_add = CallContext::kParallel;
218         } else {
219           CHECK_EQ(callsite.context(), CallContext::kSequential);
220           context_to_add = node->context();
221         }
222         CallContext new_context =
223             UnionContexts(context_to_add, callee_node.context());
224 
225         if (new_context != callee_node.context()) {
226           // Context of computation has been changed so add node to worklist.
227           callee_node.set_context(new_context);
228           worklist.push(&callee_node);
229         }
230       }
231     }
232   }
233 
234   // No node should have a kNone calling context.
235   for (const HloComputation* computation : module_->computations()) {
236     CHECK_NE(GetNode(computation).context(), CallContext::kNone);
237   }
238 }
239 
SetNodeDepths()240 void CallGraph::SetNodeDepths() {
241   std::queue<CallGraphNode*> worklist;
242 
243   // Initialize node depths to -1.
244   for (CallGraphNode& node : nodes_) {
245     node.set_depth(-1);
246   }
247 
248   // Initialize worklist with all roots of the call graph (computations without
249   // callers).
250   for (const HloComputation* computation : module_->computations()) {
251     CallGraphNode& node = GetNode(computation);
252     if (node.callers().empty()) {
253       node.set_depth(0);
254       worklist.push(&node);
255     }
256   }
257 
258   while (!worklist.empty()) {
259     CallGraphNode* node = worklist.front();
260     worklist.pop();
261     for (const HloComputation* callee : node->callees()) {
262       CallGraphNode& callee_node = GetNode(callee);
263       if (callee_node.depth() < node->depth() + 1) {
264         callee_node.set_depth(node->depth() + 1);
265         worklist.push(&callee_node);
266       }
267     }
268   }
269 
270   for (CallGraphNode& node : nodes_) {
271     CHECK_NE(node.depth(), -1);
272   }
273 }
274 
275 /* static */
Build(const HloModule * module)276 std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
277   // Constructor for CallGraph is private so absl::make_unique can't be used.
278   auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
279 
280   VLOG(2) << "Building call graph for:";
281   XLA_VLOG_LINES(2, module->ToString());
282 
283   // Construct nodes of the call graph and populate the callsites.
284   for (HloComputation* computation : module->computations()) {
285     auto it_added = call_graph->node_indices_.insert(
286         {computation, call_graph->nodes_.size()});
287     // All computations should be unique, so the computation should not already
288     // exist in the map.
289     CHECK(it_added.second);
290     call_graph->nodes_.emplace_back(computation);
291 
292     // Add all callsites in this computation.
293     for (HloInstruction* instruction : computation->instructions()) {
294       call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
295     }
296   }
297 
298   // Add caller callsites to each node.
299   for (const HloComputation* computation : module->computations()) {
300     for (const CallSite& callsite :
301          call_graph->GetNode(computation).callsites()) {
302       for (auto* callee : callsite.called_computations()) {
303         // Add caller callsites.
304         call_graph->GetNode(callee).AddCallerCallSite(callsite);
305       }
306     }
307   }
308 
309   call_graph->SetCallContexts();
310   call_graph->SetNodeDepths();
311 
312   XLA_VLOG_LINES(1, call_graph->ToString());
313 
314   return call_graph;
315 }
316 
VisitNodesInternal(const VisitorFunction & visitor_func,const CallGraphNode & node,absl::flat_hash_set<const CallGraphNode * > * visited) const317 Status CallGraph::VisitNodesInternal(
318     const VisitorFunction& visitor_func, const CallGraphNode& node,
319     absl::flat_hash_set<const CallGraphNode*>* visited) const {
320   auto pair = visited->insert(&node);
321   if (!pair.second) {
322     // Node was not inserted. Node has already been visited.
323     return Status::OK();
324   }
325 
326   for (const HloComputation* computation : node.callees()) {
327     TF_RETURN_IF_ERROR(
328         VisitNodesInternal(visitor_func, GetNode(computation), visited));
329   }
330 
331   return visitor_func(node);
332 }
333 
VisitNodes(const VisitorFunction & visitor_func,bool visit_unreachable_nodes) const334 Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
335                              bool visit_unreachable_nodes) const {
336   absl::flat_hash_set<const CallGraphNode*> visited;
337   if (visit_unreachable_nodes) {
338     // Traverse from all roots in the call graph.
339     for (const CallGraphNode& node : nodes()) {
340       if (node.callers().empty()) {
341         TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
342       }
343     }
344   } else {
345     // Traverse only from the entry computation.
346     TF_RETURN_IF_ERROR(VisitNodesInternal(
347         visitor_func, GetNode(module_->entry_computation()), &visited));
348   }
349 
350   return Status::OK();
351 }
352 
IsFlattened() const353 bool CallGraph::IsFlattened() const {
354   for (const CallGraphNode& node : nodes_) {
355     if (node.context() == CallContext::kBoth) {
356       return false;
357     }
358     if (node.context() == CallContext::kSequential &&
359         node.caller_callsites().size() > 1) {
360       return false;
361     }
362   }
363   return true;
364 }
365 
GetComputationCallers(HloComputation * c)366 std::vector<HloInstruction*> CallGraph::GetComputationCallers(
367     HloComputation* c) {
368   std::vector<HloInstruction*> callers;
369   for (auto callsite : GetNode(c).caller_callsites()) {
370     callers.push_back(callsite.instruction());
371   }
372   return callers;
373 }
374 
375 std::pair<HloInstruction*, HloInstruction*>
NearestAncestorsInSameComputation(HloInstruction * a,HloInstruction * b) const376 CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
377                                              HloInstruction* b) const {
378   // Lambda which returns the next instruction in the callee->caller chain in
379   // the call graph. This is the unique instruction which calls the computation
380   // containing 'instruction'. If more than one instruction calls the
381   // computation containing 'instruction' or no instructions call the
382   // computation then nullptr is returned.
383   auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
384     const CallGraphNode& node = GetNode(instruction->parent());
385     if (node.caller_callsites().size() != 1) {
386       return nullptr;
387     }
388     return node.caller_callsites()[0].instruction();
389   };
390 
391   // Iterate through the callee->caller chains and find the earliest common
392   // element.
393   HloInstruction* a_ancestor = a;
394   HloInstruction* b_ancestor = b;
395   int a_depth = GetNode(a->parent()).depth();
396   int b_depth = GetNode(b->parent()).depth();
397 
398   // Advance a_ancestor (b_ancestor) up the call chain until the call depth of
399   // a_ancestor or b_ancestor are the same. Necessarily each call to next_caller
400   // reduces the depth by exactly one.
401   if (a_depth > b_depth) {
402     for (int i = 0; i < a_depth - b_depth; ++i) {
403       a_ancestor = next_caller(a_ancestor);
404       if (a_ancestor == nullptr) {
405         return {nullptr, nullptr};
406       }
407     }
408   } else if (b_depth > a_depth) {
409     for (int i = 0; i < b_depth - a_depth; ++i) {
410       b_ancestor = next_caller(b_ancestor);
411       if (b_ancestor == nullptr) {
412         return {nullptr, nullptr};
413       }
414     }
415   }
416 
417   while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
418     if (a_ancestor->parent() == b_ancestor->parent()) {
419       return {a_ancestor, b_ancestor};
420     }
421 
422     a_ancestor = next_caller(a_ancestor);
423     b_ancestor = next_caller(b_ancestor);
424   }
425   return {nullptr, nullptr};
426 }
427 
ToString() const428 string CallGraph::ToString() const {
429   string out;
430   StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
431   for (const CallGraphNode& node : nodes()) {
432     StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
433     StrAppendFormat(&out, "  calls:\n");
434     for (const HloComputation* callee : node.callees()) {
435       StrAppendFormat(&out, "    %s\n", callee->name());
436     }
437     StrAppendFormat(&out, "  called by:\n");
438     for (const HloComputation* caller : node.callers()) {
439       StrAppendFormat(&out, "    %s\n", caller->name());
440     }
441     StrAppendFormat(&out, "  callsites:\n");
442     for (const CallSite& callsite : node.callsites()) {
443       StrAppendFormat(&out, "    %s\n", callsite.ToString());
444     }
445   }
446   return out;
447 }
448 
449 }  // namespace xla
450