/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using ::testing::UnorderedElementsAre;

class CallGraphTest : public HloTestBase {
 protected:
  // Build and return a trivial computation taking and returning a scalar.
  std::unique_ptr<HloComputation> MakeScalarComputation(
      HloOpcode opcode = HloOpcode::kNegate) {
    HloComputation::Builder builder(TestName() + ".ScalarComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(
        HloInstruction::CreateUnary(kScalarShape, opcode, param0));
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and maps (kMap) the
  // given computation over the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeMappingComputation(
      HloComputation* map_computation, int64 callsites) {
    HloComputation::Builder builder(TestName() + ".MappingComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateMap(
          kScalarShape, {last_value}, map_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and calls (kCall) the
  // given computation on the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeCallingComputation(
      HloComputation* callee_computation, int64 callsites,
      const string& suffix = ".CallingComputation") {
    HloComputation::Builder builder(TestName() + suffix);
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateCall(
          kScalarShape, {last_value}, callee_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and returns a PRED
  // value.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    HloComputation::Builder builder(TestName() + ".ConditionComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* zero = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
                                      zero, ComparisonDirection::kGt));
    return builder.Build();
  }

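  // Scalar shape used by every computation built in these tests.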
  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
};

TEST_F(CallGraphTest, SingletonComputation) {
  // Test the call graph of a module with a single computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(1, call_graph->nodes().size());
  EXPECT_TRUE(call_graph->IsFlattened());

  const CallGraphNode& node = call_graph->GetNode(computation);
  EXPECT_EQ(computation, node.computation());
  EXPECT_EQ(node.depth(), 0);
  EXPECT_TRUE(node.callsites().empty());
  EXPECT_TRUE(node.callees().empty());
  EXPECT_TRUE(node.caller_callsites().empty());
  EXPECT_TRUE(node.callers().empty());
  EXPECT_EQ(CallContext::kSequential, node.context());
}

TEST_F(CallGraphTest, UnreachableComputation) {
  // Test the call graph of a module with an entry computation and an
  // unreachable computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kSequential, entry_node.context());

  const CallGraphNode& unreachable_node =
      call_graph->GetNode(unreachable_computation);
  EXPECT_EQ(unreachable_node.depth(), 0);
  EXPECT_EQ(unreachable_computation, unreachable_node.computation());
  EXPECT_EQ(CallContext::kSequential, unreachable_node.context());
}

TEST_F(CallGraphTest, ParallelComputation) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a parallel context via kMap.
  auto module = CreateNewVerifiedModule();
  HloComputation* map_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeMappingComputation(map_computation, /*callsites=*/5));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, entry_node.context());
  EXPECT_EQ(5, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& map_node = call_graph->GetNode(map_computation);
  EXPECT_EQ(map_computation, map_node.computation());
  EXPECT_EQ(map_node.depth(), 1);
  EXPECT_EQ(CallContext::kParallel, map_node.context());
  EXPECT_TRUE(map_node.callsites().empty());
  EXPECT_TRUE(map_node.callees().empty());
  EXPECT_EQ(5, map_node.caller_callsites().size());
  EXPECT_EQ(1, map_node.callers().size());
}

TEST_F(CallGraphTest, SequentialComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a sequential context via kCall.
  auto module = CreateNewVerifiedModule();
  HloComputation* called_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeCallingComputation(called_computation, /*callsites=*/3));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  // The called computation is only called from one other computation, but there
  // are multiple callsites.
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kSequential, entry_node.context());
  EXPECT_EQ(3, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& called_node = call_graph->GetNode(called_computation);
  EXPECT_EQ(called_computation, called_node.computation());
  EXPECT_EQ(CallContext::kSequential, called_node.context());
  EXPECT_TRUE(called_node.callsites().empty());
  EXPECT_TRUE(called_node.callees().empty());
  EXPECT_EQ(3, called_node.caller_callsites().size());
  EXPECT_EQ(1, called_node.callers().size());
}

TEST_F(CallGraphTest, ContextBothComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in both a parallel and sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(2, entry_node.callsites().size());

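  // The kCall instruction produces a callsite with a sequential context.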
  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kSequential, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

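  // The kMap instruction produces a callsite with a parallel context.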
  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kParallel, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(sub_node.depth(), 1);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}

TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(1, entry_node.callsites().size());

  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kSequential, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

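  // Each branch computation is called only from the conditional instruction
  // in the entry computation.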
  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_EQ(true_node.depth(), 1);
  EXPECT_TRUE(true_node.callees().empty());
  EXPECT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_EQ(false_node.depth(), 1);
  EXPECT_TRUE(false_node.callees().empty());
  EXPECT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}

TEST_F(CallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  const CallGraphNode& a_node = call_graph->GetNode(a_computation);
  const CallGraphNode& b_node = call_graph->GetNode(b_computation);
  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
  const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);

  // Verify depths.
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(a_node.depth(), 1);
  EXPECT_EQ(b_node.depth(), 2);
  EXPECT_EQ(c_node.depth(), 3);
  EXPECT_EQ(cond_node.depth(), 2);

  // Entry computation has one while instruction calling two computations
  // (cond_computation and a_computation).
  ASSERT_EQ(1, entry_node.callsites().size());
  const std::vector<HloComputation*>& called_computations =
      entry_node.callsites()[0].called_computations();
  EXPECT_THAT(called_computations,
              UnorderedElementsAre(cond_computation, a_computation));
  EXPECT_EQ(CallContext::kSequential, entry_node.context());

  EXPECT_TRUE(c_node.callsites().empty());
  EXPECT_THAT(c_node.callers(),
              UnorderedElementsAre(a_computation, b_computation));
  EXPECT_EQ(CallContext::kBoth, c_node.context());

  // Visit the graph and verify nodes were visited in callee-before-caller
  // order.
  std::vector<const HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return Status::OK();
  }));
  EXPECT_EQ(visited.size(), 5);
  // All values in visited should be unique.
  EXPECT_EQ(
      std::unordered_set<const HloComputation*>(visited.begin(), visited.end())
          .size(),
      5);

  // Verify visitation order of some computations in the graph.
  auto index_of = [&visited](const HloComputation* comp) {
    auto it = absl::c_find(visited, comp);
    EXPECT_NE(it, visited.end());
    return std::distance(visited.begin(), it);
  };
  EXPECT_EQ(4, index_of(entry_computation));
  EXPECT_LT(index_of(cond_computation), index_of(a_computation));
  EXPECT_LT(index_of(c_computation), index_of(b_computation));
  EXPECT_LT(index_of(b_computation), index_of(a_computation));

  // Verify dominance relations between computations in the graph.

  // Entry dominates everybody, and is dominated by no one except itself.
  EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation));

  // 'a' only dominates 'b' and 'c'.
  EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation));

  EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation));
}

TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
  // Test NearestAncestorsInSameComputation on a call graph of a module with
  // several computations called in various contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));
  HloInstruction* b_map = b_computation->root_instruction();

  HloComputation* a_computation;
  HloInstruction* a_call;
  HloInstruction* a_while;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    a_call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    a_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, a_call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  HloInstruction* entry_while;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    entry_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());

  // Verify NearestAncestorsInSameComputation for various instructions in the
  // module.
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call),
            std::make_pair(a_call, a_call));

  // c_computation is called from more than one site, so
  // NearestAncestorsInSameComputation bails and returns nullptrs.
  std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr};
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(
                b_map, c_computation->root_instruction()),
            null_pair);

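  // Instructions in nested computations resolve to their calling instructions
  // in the nearest computation common to both call chains.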
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while),
            std::make_pair(entry_while, entry_while));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map),
            std::make_pair(a_while, a_while));
}

TEST_F(CallGraphTest, VisitSingletonComputation) {
  // Test the call graph visitor on a call graph with a single node.
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  std::vector<HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return Status::OK();
  }));
  EXPECT_THAT(visited, UnorderedElementsAre(computation));
}

TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test the call graph visitor on a call graph with an unreachable node.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/false));
    EXPECT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}

TEST_F(CallGraphTest, VisitWithError) {
  // Test that the call graph visitor properly propagates errors.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  Status status = call_graph->VisitNodes(
      [](const CallGraphNode&) { return InternalError("Visitation failed"); });

  ASSERT_FALSE(status.ok());
  ASSERT_EQ(status.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}

}  // namespace
}  // namespace xla