1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/call_graph.h"
17 
18 #include <queue>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace xla {
33 
34 using absl::StrAppendFormat;
35 using absl::StrCat;
36 
// Returns the enumerator name of `context` as a string, e.g. "kSequential" for
// CallContext::kSequential.
//
// The switch deliberately has no `default` label: every enumerator handled
// here returns directly, so a compiler that warns about unhandled enumerators
// can flag this function if a new CallContext value is added.
string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}
49 
operator <<(std::ostream & out,const CallContext & context)50 std::ostream& operator<<(std::ostream& out, const CallContext& context) {
51   out << CallContextToString(context);
52   return out;
53 }
54 
GetInstructionCallContext(HloOpcode opcode)55 CallContext GetInstructionCallContext(HloOpcode opcode) {
56   switch (opcode) {
57     case HloOpcode::kCall:
58     case HloOpcode::kConditional:
59     case HloOpcode::kWhile:
60       return CallContext::kSequential;
61     case HloOpcode::kAllReduce:
62     case HloOpcode::kMap:
63     case HloOpcode::kReduce:
64     case HloOpcode::kReduceWindow:
65     case HloOpcode::kScatter:
66     case HloOpcode::kSelectAndScatter:
67     case HloOpcode::kSort:
68     case HloOpcode::kFusion:
69     case HloOpcode::kCustomCall:
70       return CallContext::kParallel;
71     default:
72       return CallContext::kNone;
73   }
74 }
75 
ToString() const76 string CallSite::ToString() const {
77   return StrCat(
78       instruction()->name(), " calls in context ",
79       CallContextToString(context()), ": ",
80       absl::StrJoin(called_computations(), ", ",
81                     [](string* out, const HloComputation* computation) {
82                       out->append(computation->name());
83                     }));
84 }
85 
// Creates a call graph node and records `computation` as the computation it
// represents.
CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}
88 
GetCallSite(const HloInstruction * instruction) const89 const CallSite* CallGraphNode::GetCallSite(
90     const HloInstruction* instruction) const {
91   auto it = callsite_instructions_.find(instruction);
92   if (it == callsite_instructions_.end()) {
93     return nullptr;
94   }
95   return &callsites_[it->second];
96 }
97 
ToString() const98 std::string CallGraphNode::ToString() const { return computation_->name(); }
99 
AddCallerCallSite(const CallSite & caller_callsite)100 void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
101   caller_callsites_.push_back(caller_callsite);
102   HloComputation* caller = caller_callsite.instruction()->parent();
103   if (!ContainsKey(caller_set_, caller)) {
104     callers_.push_back(caller);
105     caller_set_.insert(caller);
106   }
107 }
108 
AddCallSiteForInstruction(HloInstruction * instruction)109 void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
110   CHECK_EQ(instruction->parent(), computation());
111   const CallContext context = GetInstructionCallContext(instruction->opcode());
112   if (!instruction->called_computations().empty()) {
113     CHECK(context == CallContext::kSequential ||
114           context == CallContext::kParallel);
115     callsite_instructions_.insert({instruction, callsites_.size()});
116     callsites_.push_back(
117         CallSite(instruction, instruction->called_computations(), context));
118     // Update callee computations to include any new computations called by this
119     // instruction.
120     for (auto* callee : callsites_.back().called_computations()) {
121       if (!ContainsKey(callee_set_, callee)) {
122         callees_.push_back(callee);
123         callee_set_.insert(callee);
124       }
125     }
126   }
127 }
128 
// Creates a call graph object for `module`. Only the module pointer is stored
// here; CallGraph::Build populates the nodes after construction.
CallGraph::CallGraph(const HloModule* module) : module_(module) {}
130 
GetNode(const HloComputation * computation) const131 const CallGraphNode& CallGraph::GetNode(
132     const HloComputation* computation) const {
133   auto it = node_indices_.find(computation);
134   CHECK(it != node_indices_.end());
135   return nodes_[it->second];
136 }
137 
GetNode(const HloComputation * computation)138 CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
139   auto it = node_indices_.find(computation);
140   CHECK(it != node_indices_.end());
141   return nodes_[it->second];
142 }
143 
DominatesHelper(const HloComputation * a,const HloComputation * b,absl::flat_hash_set<const HloComputation * > * visited) const144 bool CallGraph::DominatesHelper(
145     const HloComputation* a, const HloComputation* b,
146     absl::flat_hash_set<const HloComputation*>* visited) const {
147   if (a == b || ContainsKey(*visited, b)) {
148     // The call graph is guaranteed to be acyclic so any previously visited node
149     // we encounter was already determined to be dominated.
150     return true;
151   }
152 
153   const CallGraphNode& b_node = GetNode(b);
154   if (b_node.callers().empty()) {
155     // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
156     return false;
157   }
158 
159   // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
160   visited->insert(b);
161   for (const HloComputation* b_caller : b_node.callers()) {
162     if (!DominatesHelper(a, b_caller, visited)) {
163       return false;
164     }
165   }
166   return true;
167 }
168 
Dominates(const HloComputation * a,const HloComputation * b) const169 bool CallGraph::Dominates(const HloComputation* a,
170                           const HloComputation* b) const {
171   absl::flat_hash_set<const HloComputation*> visited;
172   return DominatesHelper(a, b, &visited);
173 }
174 
175 namespace {
176 
177 // Returns the call context of a computation which is called from contexts 'a'
178 // and 'b'.
UnionContexts(CallContext a,CallContext b)179 CallContext UnionContexts(CallContext a, CallContext b) {
180   if (a == CallContext::kNone) {
181     return b;
182   } else if (b == CallContext::kNone) {
183     return a;
184   } else if (a == b) {
185     return a;
186   } else {
187     // Contexts are different and neither is kNone, ie one is kSequential and
188     // the other is kParallel.
189     return CallContext::kBoth;
190   }
191 }
192 
193 }  // namespace
194 
SetCallContexts()195 void CallGraph::SetCallContexts() {
196   std::queue<CallGraphNode*> worklist;
197 
198   // Initialize worklist with all roots of the call graph (computations without
199   // callers).
200   for (const HloComputation* computation : module_->computations()) {
201     CallGraphNode& node = GetNode(computation);
202     if (node.callers().empty()) {
203       node.set_context(CallContext::kSequential);
204       worklist.push(&node);
205     }
206   }
207 
208   while (!worklist.empty()) {
209     CallGraphNode* node = worklist.front();
210     worklist.pop();
211 
212     for (const CallSite& callsite : node->callsites()) {
213       for (const HloComputation* callee : callsite.called_computations()) {
214         CallGraphNode& callee_node = GetNode(callee);
215 
216         // Update context of callee computation based on the callsite and its
217         // current context.
218         CallContext context_to_add;
219         if (callsite.context() == CallContext::kParallel) {
220           context_to_add = CallContext::kParallel;
221         } else {
222           CHECK_EQ(callsite.context(), CallContext::kSequential);
223           context_to_add = node->context();
224         }
225         CallContext new_context =
226             UnionContexts(context_to_add, callee_node.context());
227 
228         if (new_context != callee_node.context()) {
229           // Context of computation has been changed so add node to worklist.
230           callee_node.set_context(new_context);
231           worklist.push(&callee_node);
232         }
233       }
234     }
235   }
236 
237   // No node should have a kNone calling context.
238   for (const HloComputation* computation : module_->computations()) {
239     CHECK_NE(GetNode(computation).context(), CallContext::kNone);
240   }
241 }
242 
SetNodeDepths()243 void CallGraph::SetNodeDepths() {
244   std::queue<CallGraphNode*> worklist;
245 
246   // Initialize node depths to -1.
247   for (CallGraphNode& node : nodes_) {
248     node.set_depth(-1);
249   }
250 
251   // Initialize worklist with all roots of the call graph (computations without
252   // callers).
253   for (const HloComputation* computation : module_->computations()) {
254     CallGraphNode& node = GetNode(computation);
255     if (node.callers().empty()) {
256       node.set_depth(0);
257       worklist.push(&node);
258     }
259   }
260 
261   while (!worklist.empty()) {
262     CallGraphNode* node = worklist.front();
263     worklist.pop();
264     for (const HloComputation* callee : node->callees()) {
265       CallGraphNode& callee_node = GetNode(callee);
266       if (callee_node.depth() < node->depth() + 1) {
267         callee_node.set_depth(node->depth() + 1);
268         worklist.push(&callee_node);
269       }
270     }
271   }
272 
273   for (CallGraphNode& node : nodes_) {
274     CHECK_NE(node.depth(), -1);
275   }
276 }
277 
278 /* static */
Build(const HloModule * module)279 std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
280   // Constructor for CallGraph is private so absl::make_unique can't be used.
281   auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
282 
283   VLOG(3) << "Building call graph for:";
284   XLA_VLOG_LINES(3, module->ToString());
285 
286   // Construct nodes of the call graph and populate the callsites.
287   for (HloComputation* computation : module->computations()) {
288     auto it_added = call_graph->node_indices_.insert(
289         {computation, call_graph->nodes_.size()});
290     // All computations should be unique, so the computation should not already
291     // exist in the map.
292     CHECK(it_added.second);
293     call_graph->nodes_.emplace_back(computation);
294 
295     // Add all callsites in this computation.
296     for (HloInstruction* instruction : computation->instructions()) {
297       call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
298     }
299   }
300 
301   // Add caller callsites to each node.
302   for (const HloComputation* computation : module->computations()) {
303     for (const CallSite& callsite :
304          call_graph->GetNode(computation).callsites()) {
305       for (auto* callee : callsite.called_computations()) {
306         // Add caller callsites.
307         call_graph->GetNode(callee).AddCallerCallSite(callsite);
308       }
309     }
310   }
311 
312   call_graph->SetCallContexts();
313   call_graph->SetNodeDepths();
314 
315   XLA_VLOG_LINES(2, call_graph->ToString());
316 
317   return call_graph;
318 }
319 
VisitNodesInternal(const VisitorFunction & visitor_func,const CallGraphNode & node,absl::flat_hash_set<const CallGraphNode * > * visited) const320 Status CallGraph::VisitNodesInternal(
321     const VisitorFunction& visitor_func, const CallGraphNode& node,
322     absl::flat_hash_set<const CallGraphNode*>* visited) const {
323   auto pair = visited->insert(&node);
324   if (!pair.second) {
325     // Node was not inserted. Node has already been visited.
326     return Status::OK();
327   }
328 
329   for (const HloComputation* computation : node.callees()) {
330     TF_RETURN_IF_ERROR(
331         VisitNodesInternal(visitor_func, GetNode(computation), visited));
332   }
333 
334   return visitor_func(node);
335 }
336 
VisitNodes(const VisitorFunction & visitor_func,bool visit_unreachable_nodes) const337 Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
338                              bool visit_unreachable_nodes) const {
339   absl::flat_hash_set<const CallGraphNode*> visited;
340   if (visit_unreachable_nodes) {
341     // Traverse from all roots in the call graph.
342     for (const CallGraphNode& node : nodes()) {
343       if (node.callers().empty()) {
344         TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
345       }
346     }
347   } else {
348     // Traverse only from the entry computation.
349     TF_RETURN_IF_ERROR(VisitNodesInternal(
350         visitor_func, GetNode(module_->entry_computation()), &visited));
351   }
352 
353   return Status::OK();
354 }
355 
IsFlattened() const356 bool CallGraph::IsFlattened() const {
357   for (const CallGraphNode& node : nodes_) {
358     if (node.context() == CallContext::kBoth) {
359       return false;
360     }
361     if (node.context() == CallContext::kSequential &&
362         node.caller_callsites().size() > 1) {
363       return false;
364     }
365   }
366   return true;
367 }
368 
GetComputationCallers(HloComputation * c)369 std::vector<HloInstruction*> CallGraph::GetComputationCallers(
370     HloComputation* c) {
371   std::vector<HloInstruction*> callers;
372   for (const auto& callsite : GetNode(c).caller_callsites()) {
373     callers.push_back(callsite.instruction());
374   }
375   return callers;
376 }
377 
378 std::pair<HloInstruction*, HloInstruction*>
NearestAncestorsInSameComputation(HloInstruction * a,HloInstruction * b) const379 CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
380                                              HloInstruction* b) const {
381   // Lambda which returns the next instruction in the callee->caller chain in
382   // the call graph. This is the unique instruction which calls the computation
383   // containing 'instruction'. If more than one instruction calls the
384   // computation containing 'instruction' or no instructions call the
385   // computation then nullptr is returned.
386   auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
387     const CallGraphNode& node = GetNode(instruction->parent());
388     if (node.caller_callsites().size() != 1) {
389       return nullptr;
390     }
391     return node.caller_callsites()[0].instruction();
392   };
393 
394   // Iterate through the callee->caller chains and find the earliest common
395   // element.
396   HloInstruction* a_ancestor = a;
397   HloInstruction* b_ancestor = b;
398   int a_depth = GetNode(a->parent()).depth();
399   int b_depth = GetNode(b->parent()).depth();
400 
401   // Advance a_ancestor (b_ancestor) up the call chain until the call depth of
402   // a_ancestor or b_ancestor are the same. Necessarily each call to next_caller
403   // reduces the depth by exactly one.
404   if (a_depth > b_depth) {
405     for (int i = 0; i < a_depth - b_depth; ++i) {
406       a_ancestor = next_caller(a_ancestor);
407       if (a_ancestor == nullptr) {
408         return {nullptr, nullptr};
409       }
410     }
411   } else if (b_depth > a_depth) {
412     for (int i = 0; i < b_depth - a_depth; ++i) {
413       b_ancestor = next_caller(b_ancestor);
414       if (b_ancestor == nullptr) {
415         return {nullptr, nullptr};
416       }
417     }
418   }
419 
420   while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
421     if (a_ancestor->parent() == b_ancestor->parent()) {
422       return {a_ancestor, b_ancestor};
423     }
424 
425     a_ancestor = next_caller(a_ancestor);
426     b_ancestor = next_caller(b_ancestor);
427   }
428   return {nullptr, nullptr};
429 }
430 
ToString() const431 string CallGraph::ToString() const {
432   string out;
433   StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
434   for (const CallGraphNode& node : nodes()) {
435     StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
436     StrAppendFormat(&out, "  calls:\n");
437     for (const HloComputation* callee : node.callees()) {
438       StrAppendFormat(&out, "    %s\n", callee->name());
439     }
440     StrAppendFormat(&out, "  called by:\n");
441     for (const HloComputation* caller : node.callers()) {
442       StrAppendFormat(&out, "    %s\n", caller->name());
443     }
444     StrAppendFormat(&out, "  callsites:\n");
445     for (const CallSite& callsite : node.callsites()) {
446       StrAppendFormat(&out, "    %s\n", callsite.ToString());
447     }
448   }
449   return out;
450 }
451 
452 }  // namespace xla
453