1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32
33 namespace xla {
34 namespace {
35
36 class HloOrderingTest : public HloTestBase {};
37
TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
  // Tests the ordering of instructions in different computations using the
  // following HLO code:
  //
  // Entry computation:
  //   %x = Call(A, {})
  //   %y = Call(B, {%x})
  //
  // Computation A:
  //   %a = Call(C, {})
  //
  // Computation B:
  //   %param = Parameter(0)
  //   %b = Call(C, {})
  //
  // Computation C:
  //   %c = Constant(42.0f)
  //
  // Both A and B call C, which results in a diamond-shaped callgraph.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});

  auto builder_c = HloComputation::Builder("C");
  HloInstruction* c = builder_c.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloComputation* computation_c =
      module->AddEmbeddedComputation(builder_c.Build());

  // B's parameter is never read inside B; it exists so that the call %y can
  // take %x as an operand, which makes %y depend on %x.
  auto builder_b = HloComputation::Builder("B");
  builder_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  HloInstruction* b = builder_b.AddInstruction(
      HloInstruction::CreateCall(scalar_shape, {}, computation_c));
  HloComputation* computation_b =
      module->AddEmbeddedComputation(builder_b.Build());

  auto builder_a = HloComputation::Builder("A");
  HloInstruction* a = builder_a.AddInstruction(
      HloInstruction::CreateCall(scalar_shape, {}, computation_c));
  HloComputation* computation_a =
      module->AddEmbeddedComputation(builder_a.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape, {}, computation_a));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape, {x}, computation_b));
  module->AddEntryComputation(builder.Build());

  DependencyHloOrdering ordering(module.get());

  // %y has %x as an operand, so %x executes before %y.
  EXPECT_TRUE(ordering.ExecutesBefore(x, y));
  EXPECT_FALSE(ordering.ExecutesBefore(y, x));

  // %a (reached through %x) executes before %b (reached through %y).
  EXPECT_TRUE(ordering.ExecutesBefore(a, b));
  EXPECT_FALSE(ordering.ExecutesBefore(b, a));

  // %a lives in the computation called by %x: it is unordered with %x itself
  // but executes before %y.
  EXPECT_FALSE(ordering.ExecutesBefore(a, x));
  EXPECT_TRUE(ordering.ExecutesBefore(a, y));
  EXPECT_FALSE(ordering.ExecutesBefore(x, a));
  EXPECT_FALSE(ordering.ExecutesBefore(y, a));

  // %b lives in the computation called by %y: %x executes before it, and it
  // is unordered with %y itself.
  EXPECT_FALSE(ordering.ExecutesBefore(b, x));
  EXPECT_FALSE(ordering.ExecutesBefore(b, y));
  EXPECT_TRUE(ordering.ExecutesBefore(x, b));
  EXPECT_FALSE(ordering.ExecutesBefore(y, b));

  // Instruction 'c' is called from multiple callsites and should be unordered
  // relative to all other instructions in the module.
  EXPECT_FALSE(ordering.ExecutesBefore(c, a));
  EXPECT_FALSE(ordering.ExecutesBefore(c, b));
  EXPECT_FALSE(ordering.ExecutesBefore(c, x));
  EXPECT_FALSE(ordering.ExecutesBefore(c, y));
  EXPECT_FALSE(ordering.ExecutesBefore(a, c));
  EXPECT_FALSE(ordering.ExecutesBefore(b, c));
  EXPECT_FALSE(ordering.ExecutesBefore(x, c));
  EXPECT_FALSE(ordering.ExecutesBefore(y, c));
}
114
TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
  // Tests the ordering of instructions in the body and condition of a while
  // instruction. HLO code:
  //
  // body(F32[] %param):
  //   %negate = Negate(%param)
  //
  // condition(F32[] %param):
  //   %convert = Convert<PRED>(%param)
  //
  // entry:
  //   %constant = Constant(1.0)
  //   return While(%constant, body, condition)
  //
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});

  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
  auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kNegate, body_param));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
  auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Use a float literal to match the CreateR0<float> template argument.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
  module->AddEntryComputation(builder.Build());

  DependencyHloOrdering ordering(module.get());
  // The while's init value executes before the while and everything in its
  // body and condition.
  EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while));
  EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param));
  EXPECT_TRUE(ordering.ExecutesBefore(constant, convert));
  EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param));
  EXPECT_TRUE(ordering.ExecutesBefore(constant, negate));

  // The while should be unordered relative to the body and condition
  // instructions.
  EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param));
  EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param));
  EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while));
  EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while));

  // Condition instructions should be ordered before body instructions.
  EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param));
  EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param));
  EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate));
  EXPECT_TRUE(ordering.ExecutesBefore(convert, negate));

  EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
176
TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
  // An entry parameter should always be defined before any other instruction,
  // even one added to the builder before it (as the constant is here).
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  DependencyHloOrdering ordering(module.get());
  const auto& param_value = dataflow->GetValueDefinedAt(param);
  const auto& constant_value = dataflow->GetValueDefinedAt(constant);
  EXPECT_TRUE(ordering.IsDefinedBefore(param_value, constant_value));
  // EXPECT_FALSE (rather than EXPECT_TRUE(!...)) reports the failing call
  // clearly in the test output.
  EXPECT_FALSE(ordering.IsDefinedBefore(constant_value, param_value));
}
196
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
  // Tests the ordering of values (defined by dataflow analysis) in the body and
  // condition of a while instruction. HLO code:
  //
  // body(F32[] %param):
  //   %negate = Negate(%param)
  //
  // condition(F32[] %param):
  //   %convert = Convert<PRED>(%param)
  //
  // entry:
  //   %constant = Constant(1.0)
  //   %while = While(%constant, body, condition)
  //   %add = Add(%constant, %while)
  //
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});

  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
  auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kNegate, body_param));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
  auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Use a float literal to match the CreateR0<float> template argument.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kAdd, constant, xla_while));
  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
  DependencyHloOrdering ordering(module.get());

  const auto& constant_value = dataflow->GetValueDefinedAt(constant);
  const auto& negate_value = dataflow->GetValueDefinedAt(negate);
  const auto& convert_value = dataflow->GetValueDefinedAt(convert);
  const auto& while_value = dataflow->GetValueDefinedAt(xla_while);
  const auto& add_value = dataflow->GetValueDefinedAt(add);

  // Init value is defined before the while, but live range is not before the
  // while because of the use of the init value in the add.
  EXPECT_TRUE(ordering.IsDefinedBefore(constant_value, while_value));
  EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(constant_value, while_value,
                                                *dataflow));
  EXPECT_TRUE(ordering.MayInterfere(constant_value, while_value, *dataflow));

  // Any value defined in the body or condition is defined before the while, and
  // has a live range strictly before the while.
  EXPECT_TRUE(ordering.IsDefinedBefore(negate_value, while_value));
  EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(negate_value, while_value,
                                               *dataflow));
  EXPECT_FALSE(ordering.MayInterfere(negate_value, while_value, *dataflow));

  EXPECT_TRUE(ordering.IsDefinedBefore(convert_value, while_value));
  EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(convert_value, while_value,
                                               *dataflow));
  EXPECT_FALSE(ordering.MayInterfere(convert_value, while_value, *dataflow));

  // The live range of the while should be before the add.
  EXPECT_TRUE(ordering.IsDefinedBefore(while_value, add_value));
  // Compare against an unsigned literal: size() is unsigned, and mixing it
  // with a signed int triggers -Wsign-compare inside the gtest helpers.
  ASSERT_EQ(while_value.uses().size(), 1u);

  // The single use of the while's value is the add.
  const HloUse& while_use = while_value.uses()[0];
  EXPECT_EQ(while_use.instruction, add);
  EXPECT_TRUE(
      ordering.UseIsBeforeValueDefinition(while_use, add_value, *dataflow));
  EXPECT_TRUE(
      ordering.LiveRangeStrictlyBefore(while_value, add_value, *dataflow));
}
287
288 // Regression test for HloOrdering::ToString() crashing when fed a computation
289 // containing a fusion node.
TEST_F(HloOrderingTest, ToStringDoesNotCrash) {
  // The module below is a while loop whose results feed a loop fusion.
  // Printing a DependencyHloOrdering built over a module that contains a
  // fusion instruction is the scenario this regression test exercises.
  const char* const hlo_text = R"(
HloModule test_module

body.v8 {
prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.4, constant.1)
get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3
get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1
get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2
ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7)
}

condition.v4 {
constant.2 = s32[] constant(2)
prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT
}

fused_computation {
get-tuple-element.5.param_1 = f32[3]{0} parameter(1)
get-tuple-element.6.param_2 = f32[3]{0} parameter(2)
add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2)
get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0)
ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1)
}

ENTRY while.v11 {
constant.5 = s32[] constant(0)
constant.6 = f32[3]{0} constant({1, 1, 1})
constant.7 = f32[3]{0} constant({2, 2, 2})
constant.8 = f32[3]{0} constant({3, 3, 3})
tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8)
while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8
get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3
get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1
get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2
ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(hlo_text));
  DependencyHloOrdering dependency_ordering(parsed_module.get());
  // Only the absence of a crash is checked; the returned string is discarded.
  dependency_ordering.ToString();
}
338
TEST_F(HloOrderingTest, ConditionalInstructionOrdering) {
  // A conditional whose true and false branches both compute an add; the
  // results of the conditional feed another add in the entry computation.
  const char* module_str = R"(
HloModule test_conditional_module

true_branch {
param.1 = (s32[], s32[]) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1
add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1)
}

false_branch {
param.2 = (s32[], s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0
get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1
add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4)
ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4)
}

ENTRY root {
param.3 = (pred[], (s32[], s32[])) parameter(0)
pred.1 = pred[] get-tuple-element(param.3), index=0
cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1
conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch
cond_res.1 = s32[] get-tuple-element(conditional), index=0
cond_res.2 = s32[] get-tuple-element(conditional), index=1
add.3 = s32[] add(cond_res.1, cond_res.2)
ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
  DependencyHloOrdering ordering(module.get());

  // Even though the true and false branches have no ordering, they do not
  // interfere (they are mutually exclusive), so the true computation is
  // defined to be before the false one.
  // Similarly, any instruction in the true or false branch is considered to
  // be before the conditional instruction. The roots are effectively "at the
  // same time" WRT the conditional, but they are Phi-ed anyway.
  HloInstruction* add_1 = FindInstruction(module.get(), "add.1");
  HloInstruction* add_2 = FindInstruction(module.get(), "add.2");
  HloInstruction* add_3 = FindInstruction(module.get(), "add.3");
  HloInstruction* conditional = FindInstruction(module.get(), "conditional");

  const auto& add_1_value = dataflow->GetValueDefinedAt(add_1);
  const auto& add_2_value = dataflow->GetValueDefinedAt(add_2);
  const auto& add_3_value = dataflow->GetValueDefinedAt(add_3);
  const auto& conditional_value = dataflow->GetValueDefinedAt(conditional);

  EXPECT_TRUE(ordering.IsDefinedBefore(add_1_value, add_2_value));
  EXPECT_TRUE(ordering.IsDefinedBefore(add_2_value, conditional_value));
  EXPECT_TRUE(ordering.IsDefinedBefore(add_1_value, conditional_value));
  EXPECT_TRUE(ordering.IsDefinedBefore(add_1_value, add_3_value));
  EXPECT_TRUE(ordering.IsDefinedBefore(add_2_value, add_3_value));
}
399
TEST_F(HloOrderingTest,
       ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
  // A value that is live out of the module (the entry root) is expected to
  // interfere with any value scheduled after the root:
  //
  //   %param = param(0)
  //   ROOT %root = negate(%param)
  //   %dead = Constant(123.0)
  //
  // Here %root must interfere with %dead.
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(xla::F32, {});

  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* root = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, param));
  HloInstruction* dead = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  HloComputation* entry = module->AddEntryComputation(
      entry_builder.Build(/*root_instruction=*/root));

  // Explicitly sequence %dead after the root.
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, root, dead});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));

  const auto& root_value = dataflow->GetValueDefinedAt(root);
  const auto& dead_value = dataflow->GetValueDefinedAt(dead);

  EXPECT_FALSE(
      ordering.LiveRangeStrictlyBefore(root_value, dead_value, *dataflow));

  EXPECT_TRUE(ordering.MayInterfere(root_value, dead_value, *dataflow));
}
442
TEST_F(HloOrderingTest,
       ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
  // A value that is live out of a called computation (its root) is expected
  // to interfere with any value scheduled after that root:
  //
  //   subcomputation:
  //     %param = param(0)
  //     ROOT %root = negate(%param)
  //     %dead = Constant(123.0)
  //
  //   entry computation:
  //     %c = constant(42.0)
  //     ROOT %call = call({%c}), subcomputation
  //
  // Here %root must interfere with %dead.
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(xla::F32, {});

  // Build the callee, whose root is not its last instruction.
  auto sub_builder = HloComputation::Builder(TestName() + ".sub");
  HloInstruction* param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* root = sub_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, param));
  HloInstruction* dead = sub_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  HloComputation* subcomputation = module->AddEmbeddedComputation(
      sub_builder.Build(/*root_instruction=*/root));

  // Build the entry computation that calls it.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* c = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(f32_scalar, {c}, subcomputation));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  // Explicitly sequence %dead after the callee's root.
  HloSchedule schedule(module.get());
  schedule.set_sequence(subcomputation, {param, root, dead});
  schedule.set_sequence(entry, {c, call});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));

  const auto& root_value = dataflow->GetValueDefinedAt(root);
  const auto& dead_value = dataflow->GetValueDefinedAt(dead);

  EXPECT_FALSE(
      ordering.LiveRangeStrictlyBefore(root_value, dead_value, *dataflow));

  EXPECT_TRUE(ordering.MayInterfere(root_value, dead_value, *dataflow));
}
498
499 } // namespace
500 } // namespace xla
501