1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 
33 namespace xla {
34 namespace {
35 
// Fixture shared by every test in this file. It declares no members of its own
// and exists only so that the TEST_F cases inherit from HloTestBase.
class HloOrderingTest : public HloTestBase {};
37 
TEST_F(HloOrderingTest,InstructionsInDifferentComputations)38 TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
39   // Tests the ordering of instructions in different computations using the
40   // following HLO code:
41   //
42   // Entry computation:
43   //   %x = Call(A, {})
44   //   %y = Call(B, {%x})
45   //
46   // Computation A:
47   //   %a = Call(C, {})
48   //
49   // Computation B:
50   //   %b = Call(C, {})
51   //
52   // Computation C:
53   //   %c = Constant(42.0f)
54   //
55   // This results in a diamond-shaped callgraph.
56   auto module = CreateNewVerifiedModule();
57   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
58 
59   auto builder_c = HloComputation::Builder("C");
60   HloInstruction* c = builder_c.AddInstruction(
61       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
62   HloComputation* computation_c =
63       module->AddEmbeddedComputation(builder_c.Build());
64 
65   auto builder_b = HloComputation::Builder("B");
66   builder_b.AddInstruction(
67       HloInstruction::CreateParameter(0, scalar_shape, "param"));
68   HloInstruction* b = builder_b.AddInstruction(
69       HloInstruction::CreateCall(scalar_shape, {}, computation_c));
70   HloComputation* computation_b =
71       module->AddEmbeddedComputation(builder_b.Build());
72 
73   auto builder_a = HloComputation::Builder("A");
74   HloInstruction* a = builder_a.AddInstruction(
75       HloInstruction::CreateCall(scalar_shape, {}, computation_c));
76   HloComputation* computation_a =
77       module->AddEmbeddedComputation(builder_a.Build());
78 
79   auto builder = HloComputation::Builder(TestName());
80   HloInstruction* x = builder.AddInstruction(
81       HloInstruction::CreateCall(scalar_shape, {}, computation_a));
82   HloInstruction* y = builder.AddInstruction(
83       HloInstruction::CreateCall(scalar_shape, {x}, computation_b));
84   module->AddEntryComputation(builder.Build());
85 
86   DependencyHloOrdering ordering(module.get());
87   EXPECT_TRUE(ordering.ExecutesBefore(x, y));
88   EXPECT_FALSE(ordering.ExecutesBefore(y, x));
89 
90   EXPECT_TRUE(ordering.ExecutesBefore(a, b));
91   EXPECT_FALSE(ordering.ExecutesBefore(b, a));
92 
93   EXPECT_FALSE(ordering.ExecutesBefore(a, x));
94   EXPECT_TRUE(ordering.ExecutesBefore(a, y));
95   EXPECT_FALSE(ordering.ExecutesBefore(x, a));
96   EXPECT_FALSE(ordering.ExecutesBefore(y, a));
97 
98   EXPECT_FALSE(ordering.ExecutesBefore(b, x));
99   EXPECT_FALSE(ordering.ExecutesBefore(b, y));
100   EXPECT_TRUE(ordering.ExecutesBefore(x, b));
101   EXPECT_FALSE(ordering.ExecutesBefore(y, b));
102 
103   // Instruction 'c' is called from multiple callsites and should be unordered
104   // relative to all other instructions in the module.
105   EXPECT_FALSE(ordering.ExecutesBefore(c, a));
106   EXPECT_FALSE(ordering.ExecutesBefore(c, b));
107   EXPECT_FALSE(ordering.ExecutesBefore(c, x));
108   EXPECT_FALSE(ordering.ExecutesBefore(c, y));
109   EXPECT_FALSE(ordering.ExecutesBefore(a, c));
110   EXPECT_FALSE(ordering.ExecutesBefore(b, c));
111   EXPECT_FALSE(ordering.ExecutesBefore(x, c));
112   EXPECT_FALSE(ordering.ExecutesBefore(y, c));
113 }
114 
TEST_F(HloOrderingTest,InstructionsInWhileComputations)115 TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
116   // Tests the ordering of instructions in the body and condition of a while
117   // instruction. HLO code:
118   //
119   // body(F32[]) %param):
120   //   %negate = Negate(%param)
121   //
122   // condition(F32[] %param):
123   //   %convert = Convert<PRED>(%param)
124   //
125   // entry:
126   //   %constant = Constant(1.0)
127   //   return While(%constant, body, condition)
128   //
129   auto module = CreateNewVerifiedModule();
130   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
131 
132   auto body_builder = HloComputation::Builder("body");
133   auto body_param = body_builder.AddInstruction(
134       HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
135   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
136       scalar_shape, HloOpcode::kNegate, body_param));
137   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
138 
139   auto cond_builder = HloComputation::Builder("condition");
140   auto cond_param = cond_builder.AddInstruction(
141       HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
142   auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
143       ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
144   HloComputation* condition =
145       module->AddEmbeddedComputation(cond_builder.Build());
146 
147   auto builder = HloComputation::Builder(TestName());
148   auto constant = builder.AddInstruction(
149       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
150   auto xla_while = builder.AddInstruction(
151       HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
152   module->AddEntryComputation(builder.Build());
153 
154   DependencyHloOrdering ordering(module.get());
155   EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while));
156   EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param));
157   EXPECT_TRUE(ordering.ExecutesBefore(constant, convert));
158   EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param));
159   EXPECT_TRUE(ordering.ExecutesBefore(constant, negate));
160 
161   // The while should be unordered relative to the body and condition
162   // instructions.
163   EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param));
164   EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param));
165   EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while));
166   EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while));
167 
168   // Condition instructions should be ordered before body instructions.
169   EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param));
170   EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param));
171   EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate));
172   EXPECT_TRUE(ordering.ExecutesBefore(convert, negate));
173 
174   EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
175 }
176 
TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
  // An entry parameter should always be defined before any other instruction.
  // The constant is deliberately added to the computation before the
  // parameter, and the two share no data dependency, so the expected order
  // cannot come from insertion order or from operands.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  DependencyHloOrdering ordering(module.get());

  // Look each value up once; both directions of the query use the same pair.
  const HloValue& param_value = dataflow->GetValueDefinedAt(param);
  const HloValue& constant_value = dataflow->GetValueDefinedAt(constant);

  EXPECT_TRUE(ordering.IsDefinedBefore(param_value, constant_value));
  // EXPECT_FALSE rather than EXPECT_TRUE(!...) so that a failure reports the
  // queried expression itself instead of its negation.
  EXPECT_FALSE(ordering.IsDefinedBefore(constant_value, param_value));
}
196 
TEST_F(HloOrderingTest,ValuesInWhileComputations)197 TEST_F(HloOrderingTest, ValuesInWhileComputations) {
198   // Tests the ordering of values (defined by dataflow analysis) in the body and
199   // condition of a while instruction. HLO code:
200   //
201   // body(F32[]) %param):
202   //   %negate = Negate(%param)
203   //
204   // condition(F32[] %param):
205   //   %convert = Convert<PRED>(%param)
206   //
207   // entry:
208   //   %constant = Constant(1.0)
209   //   %while = While(%constant, body, condition)
210   //   %add = Add(%constant, %while)
211   //
212   auto module = CreateNewVerifiedModule();
213   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
214 
215   auto body_builder = HloComputation::Builder("body");
216   auto body_param = body_builder.AddInstruction(
217       HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
218   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
219       scalar_shape, HloOpcode::kNegate, body_param));
220   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
221 
222   auto cond_builder = HloComputation::Builder("condition");
223   auto cond_param = cond_builder.AddInstruction(
224       HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
225   auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
226       ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
227   HloComputation* condition =
228       module->AddEmbeddedComputation(cond_builder.Build());
229 
230   auto builder = HloComputation::Builder(TestName());
231   auto constant = builder.AddInstruction(
232       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
233   auto xla_while = builder.AddInstruction(
234       HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
235   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
236       scalar_shape, HloOpcode::kAdd, constant, xla_while));
237   module->AddEntryComputation(builder.Build());
238 
239   TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
240                           HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
241   DependencyHloOrdering ordering(module.get());
242 
243   // Init value is defined before the while, but live range is not before the
244   // while because of the use of the init value in the add.
245   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
246                                        dataflow->GetValueDefinedAt(xla_while)));
247   EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
248       dataflow->GetValueDefinedAt(constant),
249       dataflow->GetValueDefinedAt(xla_while), *dataflow));
250   EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
251                                     dataflow->GetValueDefinedAt(xla_while),
252                                     *dataflow));
253 
254   // Any value defined in the body or condition is defined before the while, and
255   // has a live range strictly before the while.
256   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
257                                        dataflow->GetValueDefinedAt(xla_while)));
258   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
259       dataflow->GetValueDefinedAt(negate),
260       dataflow->GetValueDefinedAt(xla_while), *dataflow));
261   EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
262                                      dataflow->GetValueDefinedAt(xla_while),
263                                      *dataflow));
264 
265   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
266                                        dataflow->GetValueDefinedAt(xla_while)));
267   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
268       dataflow->GetValueDefinedAt(convert),
269       dataflow->GetValueDefinedAt(xla_while), *dataflow));
270   EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
271                                      dataflow->GetValueDefinedAt(xla_while),
272                                      *dataflow));
273 
274   // The live range of the while should be before the add.
275   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
276                                        dataflow->GetValueDefinedAt(add)));
277   ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
278 
279   const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
280   EXPECT_EQ(while_use.instruction, add);
281   EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
282       while_use, dataflow->GetValueDefinedAt(add), *dataflow));
283   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
284       dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
285       *dataflow));
286 }
287 
288 // Regression test for HloOrdering::ToString() crashing when fed a computation
289 // containing a fusion node.
TEST_F(HloOrderingTest,ToStringDoesNotCrash)290 TEST_F(HloOrderingTest, ToStringDoesNotCrash) {
291   const char* module_str = R"(
292 HloModule test_module
293 
294 body.v8 {
295   prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
296   get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0
297   constant.1 = s32[] constant(1)
298   add = s32[] add(get-tuple-element.4, constant.1)
299   get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3
300   get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1
301   get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2
302   ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7)
303 }
304 
305 condition.v4 {
306   constant.2 = s32[] constant(2)
307   prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
308   get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
309   ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT
310 }
311 
312 fused_computation {
313   get-tuple-element.5.param_1 = f32[3]{0} parameter(1)
314   get-tuple-element.6.param_2 = f32[3]{0} parameter(2)
315   add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2)
316   get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0)
317   ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1)
318 }
319 
320 ENTRY while.v11 {
321   constant.5 = s32[] constant(0)
322   constant.6 = f32[3]{0} constant({1, 1, 1})
323   constant.7 = f32[3]{0} constant({2, 2, 2})
324   constant.8 = f32[3]{0} constant({3, 3, 3})
325   tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8)
326   while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8
327   get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3
328   get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1
329   get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2
330   ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation
331 })";
332 
333   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
334                           ParseHloString(module_str));
335   DependencyHloOrdering ordering(module.get());
336   ordering.ToString();  // Shouldn't crash.
337 }
338 
TEST_F(HloOrderingTest, ConditionalInstructionOrdering) {
  // A conditional whose two branches each add the elements of their tuple
  // argument; the entry computation then adds the two conditional results.
  const char* hlo_text = R"(
HloModule test_conditional_module

true_branch {
  param.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
  get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1
  add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
  ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1)
}

false_branch {
  param.2 = (s32[], s32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0
  get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1
  add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4)
  ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4)
}

ENTRY root {
  param.3 = (pred[], (s32[], s32[])) parameter(0)
  pred.1 = pred[] get-tuple-element(param.3), index=0
  cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1
  conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch
  cond_res.1 = s32[] get-tuple-element(conditional), index=0
  cond_res.2 = s32[] get-tuple-element(conditional), index=1
  add.3 = s32[] add(cond_res.1, cond_res.2)
  ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(
      auto dataflow,
      HloDataflowAnalysis::Run(*parsed_module, /*ssa_form=*/true));
  DependencyHloOrdering ordering(parsed_module.get());

  // The true and false branches have no real ordering between them, but since
  // they are mutually exclusive and therefore do not interfere, the true
  // computation is defined to come before the false one.
  // Likewise, every instruction in either branch is considered to come before
  // the conditional instruction. The branch roots are effectively "at the
  // same time" as the conditional, but they are Phi-ed anyway.
  HloInstruction* true_add = FindInstruction(parsed_module.get(), "add.1");
  HloInstruction* false_add = FindInstruction(parsed_module.get(), "add.2");
  HloInstruction* entry_add = FindInstruction(parsed_module.get(), "add.3");
  HloInstruction* conditional =
      FindInstruction(parsed_module.get(), "conditional");

  // True-branch add before false-branch add.
  EXPECT_TRUE(
      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(true_add),
                               dataflow->GetValueDefinedAt(false_add)));
  // Both branch adds before the conditional itself.
  EXPECT_TRUE(
      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(false_add),
                               dataflow->GetValueDefinedAt(conditional)));
  EXPECT_TRUE(
      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(true_add),
                               dataflow->GetValueDefinedAt(conditional)));
  // Both branch adds before the add that consumes the conditional's results.
  EXPECT_TRUE(
      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(true_add),
                               dataflow->GetValueDefinedAt(entry_add)));
  EXPECT_TRUE(
      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(false_add),
                               dataflow->GetValueDefinedAt(entry_add)));
}
399 
TEST_F(HloOrderingTest,
       ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
  // A value that is live out of the module must interfere with any value
  // defined after the root instruction. The entry computation is scheduled as:
  //
  //   %param = param(0)
  //   ROOT %root = negate(%param)
  //   %dead = Constant(123.0)
  //
  // so %root is expected to interfere with %dead.
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(xla::F32, {});

  HloComputation::Builder entry_builder(TestName());
  auto* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  auto* root = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, param));
  auto* dead = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  // The negate is named as the root explicitly because 'dead' was added last.
  auto* entry = module->AddEntryComputation(
      entry_builder.Build(/*root_instruction=*/root));

  // Explicit schedule that places 'dead' after the root.
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, root, dead});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  // The schedule puts the root strictly before the dead constant ...
  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));

  // ... yet the root's live range does not end before 'dead' is defined,
  const auto& root_value = dataflow->GetValueDefinedAt(root);
  const auto& dead_value = dataflow->GetValueDefinedAt(dead);
  EXPECT_FALSE(
      ordering.LiveRangeStrictlyBefore(root_value, dead_value, *dataflow));

  // which means the two values may interfere.
  EXPECT_TRUE(ordering.MayInterfere(root_value, dead_value, *dataflow));
}
442 
TEST_F(HloOrderingTest,
       ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
  // A value that is live out of a computation must interfere with any value
  // defined after that computation's root instruction. The module is:
  //
  // subcomputation:
  //   %param = param(0)
  //   ROOT %root = negate(%param)
  //   %dead = Constant(123.0)
  //
  // entry computation:
  //   %c = constant(42.0)
  //   ROOT %call = call({%c}), subcomputation
  //
  // so %root is expected to interfere with %dead.
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(xla::F32, {});

  // Called computation; the negate is named as root explicitly because 'dead'
  // is added after it.
  HloComputation::Builder sub_builder(TestName() + ".sub");
  auto* param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  auto* root = sub_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, param));
  auto* dead = sub_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  auto* subcomputation = module->AddEmbeddedComputation(
      sub_builder.Build(/*root_instruction=*/root));

  // Entry computation: call the subcomputation on a constant.
  HloComputation::Builder entry_builder(TestName());
  auto* c = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(f32_scalar, {c}, subcomputation));
  auto* entry = module->AddEntryComputation(entry_builder.Build());

  // Explicit schedule for both computations; 'dead' follows the sub-root.
  HloSchedule schedule(module.get());
  schedule.set_sequence(subcomputation, {param, root, dead});
  schedule.set_sequence(entry, {c, call});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));

  // The schedule puts the root strictly before the dead constant ...
  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));

  // ... yet the root's live range does not end before 'dead' is defined,
  const auto& root_value = dataflow->GetValueDefinedAt(root);
  const auto& dead_value = dataflow->GetValueDefinedAt(dead);
  EXPECT_FALSE(
      ordering.LiveRangeStrictlyBefore(root_value, dead_value, *dataflow));

  // which means the two values may interfere.
  EXPECT_TRUE(ordering.MayInterfere(root_value, dead_value, *dataflow));
}
498 
499 }  // namespace
500 }  // namespace xla
501