1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/call_graph.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 
29 namespace xla {
30 namespace {
31 
32 using ::testing::UnorderedElementsAre;
33 
34 class CallGraphTest : public HloTestBase {
35  protected:
36   // Build and return a trivial computation taking and returning a scalar.
MakeScalarComputation(HloOpcode opcode=HloOpcode::kNegate)37   std::unique_ptr<HloComputation> MakeScalarComputation(
38       HloOpcode opcode = HloOpcode::kNegate) {
39     HloComputation::Builder builder(TestName() + ".ScalarComputation");
40     HloInstruction* param0 = builder.AddInstruction(
41         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
42     builder.AddInstruction(
43         HloInstruction::CreateUnary(kScalarShape, opcode, param0));
44     return builder.Build();
45   }
46 
47   // Build and return a computation which takes a scalar and maps (kMap) the
48   // given computation to the value 'callsites' number of times.
MakeMappingComputation(HloComputation * map_computation,int64 callsites)49   std::unique_ptr<HloComputation> MakeMappingComputation(
50       HloComputation* map_computation, int64 callsites) {
51     HloComputation::Builder builder(TestName() + ".MappingComputation");
52     HloInstruction* param0 = builder.AddInstruction(
53         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
54     HloInstruction* last_value = param0;
55     for (int64 i = 0; i < callsites; ++i) {
56       last_value = builder.AddInstruction(HloInstruction::CreateMap(
57           kScalarShape, {last_value}, map_computation));
58     }
59     return builder.Build();
60   }
61 
62   // Build and return a computation which takes a scalar and calls (kCall) the
63   // given computation with value 'callsites' number of times.
MakeCallingComputation(HloComputation * callee_computation,int64 callsites,const string & suffix=".CallingComputation")64   std::unique_ptr<HloComputation> MakeCallingComputation(
65       HloComputation* callee_computation, int64 callsites,
66       const string& suffix = ".CallingComputation") {
67     HloComputation::Builder builder(TestName() + suffix);
68     HloInstruction* param0 = builder.AddInstruction(
69         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
70     HloInstruction* last_value = param0;
71     for (int64 i = 0; i < callsites; ++i) {
72       last_value = builder.AddInstruction(HloInstruction::CreateCall(
73           kScalarShape, {last_value}, callee_computation));
74     }
75     return builder.Build();
76   }
77 
78   // Build and return a computation which takes a scalar and returns a PRED
79   // value.
MakeConditionComputation()80   std::unique_ptr<HloComputation> MakeConditionComputation() {
81     HloComputation::Builder builder(TestName() + ".ConditionComputation");
82     HloInstruction* param0 = builder.AddInstruction(
83         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
84     HloInstruction* zero = builder.AddInstruction(
85         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
86     builder.AddInstruction(
87         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
88                                       zero, ComparisonDirection::kGt));
89     return builder.Build();
90   }
91 
92   const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
93 };
94 
TEST_F(CallGraphTest, SingletonComputation) {
  // A module holding only an entry computation yields a one-node call graph.
  // That graph is flattened, and its node has no edges in either direction and
  // runs in a sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(MakeScalarComputation());
  const std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(1, graph->nodes().size());
  EXPECT_TRUE(graph->IsFlattened());

  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);

  // Outgoing edges.
  EXPECT_TRUE(entry_node.callsites().empty());
  EXPECT_TRUE(entry_node.callees().empty());
  // Incoming edges.
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  EXPECT_EQ(CallContext::kSequential, entry_node.context());
}
113 
TEST_F(CallGraphTest, UnreachableComputation) {
  // The module has an entry computation plus an embedded computation that
  // nothing calls. The call graph still holds a node for the uncalled one, and
  // that node looks like another root: depth 0, sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());
  HloComputation* orphan =
      module->AddEmbeddedComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Neither computation has a caller, so both nodes get identical checks.
  auto expect_root_node = [&graph](HloComputation* comp) {
    const CallGraphNode& node = graph->GetNode(comp);
    EXPECT_EQ(node.depth(), 0);
    EXPECT_EQ(comp, node.computation());
    EXPECT_EQ(CallContext::kSequential, node.context());
  };
  expect_root_node(entry);
  expect_root_node(orphan);
}
137 
TEST_F(CallGraphTest, ParallelComputation) {
  // The entry computation applies a scalar subcomputation through a chain of
  // kMap instructions. kMap calls its computation in a parallel context, so
  // the callee is kParallel while the entry stays kSequential.
  constexpr int64 kNumMaps = 5;
  auto module = CreateNewVerifiedModule();
  HloComputation* mapped =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry = module->AddEntryComputation(
      MakeMappingComputation(mapped, /*callsites=*/kNumMaps));

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Entry: the root, with one callee reached through every kMap.
  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, entry_node.context());
  EXPECT_EQ(kNumMaps, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  // Mapped computation: a leaf with kNumMaps incoming callsites, all from a
  // single caller.
  const CallGraphNode& mapped_node = graph->GetNode(mapped);
  EXPECT_EQ(mapped, mapped_node.computation());
  EXPECT_EQ(mapped_node.depth(), 1);
  EXPECT_EQ(CallContext::kParallel, mapped_node.context());
  EXPECT_TRUE(mapped_node.callsites().empty());
  EXPECT_TRUE(mapped_node.callees().empty());
  EXPECT_EQ(kNumMaps, mapped_node.caller_callsites().size());
  EXPECT_EQ(1, mapped_node.callers().size());
}
168 
TEST_F(CallGraphTest, SequentialComputations) {
  // The entry computation invokes one subcomputation through a chain of kCall
  // instructions. kCall calls its computation in a sequential context.
  constexpr int64 kNumCalls = 3;
  auto module = CreateNewVerifiedModule();
  HloComputation* callee =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* caller = module->AddEntryComputation(
      MakeCallingComputation(callee, /*callsites=*/kNumCalls));

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Only one computation calls the callee, but it does so from several
  // callsites. That alone is enough to make the graph non-flat.
  EXPECT_FALSE(graph->IsFlattened());

  const CallGraphNode& caller_node = graph->GetNode(caller);
  EXPECT_EQ(caller, caller_node.computation());
  EXPECT_EQ(CallContext::kSequential, caller_node.context());
  EXPECT_EQ(kNumCalls, caller_node.callsites().size());
  EXPECT_EQ(1, caller_node.callees().size());
  EXPECT_TRUE(caller_node.caller_callsites().empty());
  EXPECT_TRUE(caller_node.callers().empty());

  const CallGraphNode& callee_node = graph->GetNode(callee);
  EXPECT_EQ(callee, callee_node.computation());
  EXPECT_EQ(CallContext::kSequential, callee_node.context());
  EXPECT_TRUE(callee_node.callsites().empty());
  EXPECT_TRUE(callee_node.callees().empty());
  EXPECT_EQ(kNumCalls, callee_node.caller_callsites().size());
  EXPECT_EQ(1, callee_node.callers().size());
}
201 
TEST_F(CallGraphTest, ContextBothComputations) {
  // The entry computation invokes the same subcomputation twice: first via
  // kCall (sequential context), then via kMap (parallel context) applied to
  // the call's result. The subcomputation's node context must therefore be
  // kBoth.
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  // The subcomputation is reached from two callsites, so the graph is not
  // flattened.
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Use ASSERT, not EXPECT: callsites()[0] and [1] are read below. If the
  // count were wrong, indexing past the end would be undefined behavior
  // instead of a clean test failure.
  ASSERT_EQ(2, entry_node.callsites().size());

  // The test expects callsites in the entry's instruction order: call first.
  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kSequential, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

  // ...then map.
  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kParallel, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(sub_node.depth(), 1);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}
246 
TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional. The entry computation
  // holds one kConditional. Its true branch (kCeil) and false branch (kFloor)
  // are two distinct computations, each called exactly once.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Fatal check: callsites()[0] is read below. Indexing an empty vector would
  // be undefined behavior rather than a reportable failure.
  ASSERT_EQ(1, entry_node.callsites().size());

  // The single callsite is the conditional. It calls both branch computations
  // in a sequential context.
  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kSequential, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_EQ(true_node.depth(), 1);
  EXPECT_TRUE(true_node.callees().empty());
  // Fatal for the same reason: callers()[0] is read next.
  ASSERT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_EQ(false_node.depth(), 1);
  EXPECT_TRUE(false_node.callees().empty());
  ASSERT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}
297 
TEST_F(CallGraphTest,ComplexGraph)298 TEST_F(CallGraphTest, ComplexGraph) {
299   // Test a call graph of a module with several computation called in various
300   // contexts. The call graph looks like:
301   //
302   //      entry
303   //      /  |
304   //     a   |
305   //   / | \ |
306   //  b  |  cond
307   //   \ |
308   //    c
309   //
310   // Calls are made via kCall, kWhile, and kMap instructions.
311   auto module = CreateNewVerifiedModule();
312   HloComputation* cond_computation =
313       module->AddEmbeddedComputation(MakeConditionComputation());
314   HloComputation* c_computation =
315       module->AddEmbeddedComputation(MakeScalarComputation());
316   HloComputation* b_computation = module->AddEmbeddedComputation(
317       MakeMappingComputation(c_computation, /*callsites=*/1));
318 
319   HloComputation* a_computation;
320   {
321     HloComputation::Builder builder(TestName() + ".a");
322     HloInstruction* param0 = builder.AddInstruction(
323         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
324     HloInstruction* call = builder.AddInstruction(
325         HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
326     builder.AddInstruction(HloInstruction::CreateWhile(
327         kScalarShape, cond_computation, b_computation, call));
328     a_computation = module->AddEmbeddedComputation(builder.Build());
329   }
330 
331   HloComputation* entry_computation;
332   {
333     HloComputation::Builder builder(TestName() + ".entry");
334     HloInstruction* param0 = builder.AddInstruction(
335         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
336     builder.AddInstruction(HloInstruction::CreateWhile(
337         kScalarShape, cond_computation, a_computation, param0));
338     entry_computation = module->AddEntryComputation(builder.Build());
339   }
340 
341   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
342   EXPECT_EQ(5, call_graph->nodes().size());
343   EXPECT_FALSE(call_graph->IsFlattened());
344 
345   const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
346   const CallGraphNode& a_node = call_graph->GetNode(a_computation);
347   const CallGraphNode& b_node = call_graph->GetNode(b_computation);
348   const CallGraphNode& c_node = call_graph->GetNode(c_computation);
349   const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
350 
351   // Verify depths.
352   EXPECT_EQ(entry_node.depth(), 0);
353   EXPECT_EQ(a_node.depth(), 1);
354   EXPECT_EQ(b_node.depth(), 2);
355   EXPECT_EQ(c_node.depth(), 3);
356   EXPECT_EQ(cond_node.depth(), 2);
357 
358   // Entry computation has one while instruction calling two computations
359   // (cond_computation and a_computation).
360   ASSERT_EQ(1, entry_node.callsites().size());
361   const std::vector<HloComputation*>& called_computations =
362       entry_node.callsites()[0].called_computations();
363   EXPECT_THAT(called_computations,
364               UnorderedElementsAre(cond_computation, a_computation));
365   EXPECT_EQ(CallContext::kSequential, entry_node.context());
366 
367   EXPECT_TRUE(c_node.callsites().empty());
368   EXPECT_THAT(c_node.callers(),
369               UnorderedElementsAre(a_computation, b_computation));
370   EXPECT_EQ(CallContext::kBoth, c_node.context());
371 
372   // Visit the graph and verify nodes were visited in callee-before-caller
373   // order.
374   std::vector<const HloComputation*> visited;
375   TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
376     visited.push_back(node.computation());
377     return Status::OK();
378   }));
379   EXPECT_EQ(visited.size(), 5);
380   // All values in visited should be unique.
381   EXPECT_EQ(
382       std::unordered_set<const HloComputation*>(visited.begin(), visited.end())
383           .size(),
384       5);
385 
386   // Verify visitation order of some computations in the graph.
387   auto index_of = [&visited](const HloComputation* comp) {
388     auto it = absl::c_find(visited, comp);
389     EXPECT_NE(it, visited.end());
390     return std::distance(visited.begin(), it);
391   };
392   EXPECT_EQ(4, index_of(entry_computation));
393   EXPECT_LT(index_of(cond_computation), index_of(a_computation));
394   EXPECT_LT(index_of(c_computation), index_of(b_computation));
395   EXPECT_LT(index_of(b_computation), index_of(a_computation));
396 
397   // Verify dominance relations between computation in the graph.
398 
399   // Entry dominates everybody, and is dominated by no one except itself.
400   EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation));
401   EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation));
402   EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation));
403   EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation));
404   EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation));
405   EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation));
406   EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation));
407   EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation));
408   EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation));
409 
410   // 'a' only dominates 'b' and 'c'.
411   EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation));
412   EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation));
413   EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation));
414   EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation));
415   EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation));
416   EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation));
417 
418   EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation));
419   EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation));
420   EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation));
421 
422   EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation));
423   EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation));
424   EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation));
425 
426   EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation));
427 }
428 
TEST_F(CallGraphTest,ComplexGraphNearestAncestors)429 TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
430   // Test NearestAncestorsInSameComputation on a call graph of a module with
431   // several computation called in various contexts. The call graph looks like:
432   //
433   //      entry
434   //      /  |
435   //     a   |
436   //   / | \ |
437   //  b  |  cond
438   //   \ |
439   //    c
440   //
441   // Calls are made via kCall, kWhile, and kMap instructions.
442   auto module = CreateNewVerifiedModule();
443   HloComputation* cond_computation =
444       module->AddEmbeddedComputation(MakeConditionComputation());
445   HloComputation* c_computation =
446       module->AddEmbeddedComputation(MakeScalarComputation());
447   HloComputation* b_computation = module->AddEmbeddedComputation(
448       MakeMappingComputation(c_computation, /*callsites=*/1));
449   HloInstruction* b_map = b_computation->root_instruction();
450 
451   HloComputation* a_computation;
452   HloInstruction* a_call;
453   HloInstruction* a_while;
454   {
455     HloComputation::Builder builder(TestName() + ".a");
456     HloInstruction* param0 = builder.AddInstruction(
457         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
458     a_call = builder.AddInstruction(
459         HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
460     a_while = builder.AddInstruction(HloInstruction::CreateWhile(
461         kScalarShape, cond_computation, b_computation, a_call));
462     a_computation = module->AddEmbeddedComputation(builder.Build());
463   }
464 
465   HloComputation* entry_computation;
466   HloInstruction* entry_while;
467   {
468     HloComputation::Builder builder(TestName() + ".entry");
469     HloInstruction* param0 = builder.AddInstruction(
470         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
471     entry_while = builder.AddInstruction(HloInstruction::CreateWhile(
472         kScalarShape, cond_computation, a_computation, param0));
473     entry_computation = module->AddEntryComputation(builder.Build());
474   }
475 
476   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
477   EXPECT_EQ(5, call_graph->nodes().size());
478 
479   // Verify NearestAncestorsInSameComputation for various instructions in the
480   // module.
481   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call),
482             std::make_pair(a_call, a_call));
483 
484   // c_computation is called from more than one site, so
485   // NearestAncestorsInSameComputation bails and returns nullptrs.
486   std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr};
487   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(
488                 b_map, c_computation->root_instruction()),
489             null_pair);
490 
491   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while),
492             std::make_pair(entry_while, entry_while));
493   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call),
494             std::make_pair(a_while, a_call));
495   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call),
496             std::make_pair(a_while, a_call));
497   EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map),
498             std::make_pair(a_while, a_while));
499 }
500 
TEST_F(CallGraphTest, VisitSingletonComputation) {
  // On a one-node call graph, VisitNodes must invoke the visitor for that
  // node and nothing else.
  auto module = CreateNewVerifiedModule();
  HloComputation* only_computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());

  std::vector<HloComputation*> seen;
  auto record = [&seen](const CallGraphNode& node) {
    seen.push_back(node.computation());
    return Status::OK();
  };
  TF_ASSERT_OK(graph->VisitNodes(record));
  EXPECT_THAT(seen, UnorderedElementsAre(only_computation));
}
515 
TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test that the call graph visitor honors visit_unreachable_nodes. With
  // false, only the entry computation is visited. With true, the uncalled
  // embedded computation is visited as well.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/false));
    // ASSERT rather than EXPECT: visited[0] is read next. Reading it from an
    // empty vector would be undefined behavior instead of a clean failure.
    ASSERT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}
552 
TEST_F(CallGraphTest, VisitWithError) {
  // A failing visitor's error must come back out of VisitNodes with its code
  // (INTERNAL) and message intact.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());

  auto failing_visitor = [](const CallGraphNode&) {
    return InternalError("Visitation failed");
  };
  const Status result = graph->VisitNodes(failing_visitor);

  ASSERT_FALSE(result.ok());
  ASSERT_EQ(result.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(result.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}
567 
568 }  // namespace
569 }  // namespace xla
570