1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace xla {
31 namespace {
32 
33 class HloLivenessAnalysisTest : public HloTestBase {
34  protected:
HloLivenessAnalysisTest()35   HloLivenessAnalysisTest() {}
36 
37   // Run liveness analysis on the member module. For convenience returns a
38   // reference to the generated analysis stored in analysis_.
RunLiveness(HloModule * module)39   const HloLivenessAnalysis& RunLiveness(HloModule* module) {
40     liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie();
41     return *liveness_;
42   }
43 
GetInstruction(HloModule * module,const string & name)44   HloInstruction* GetInstruction(HloModule* module, const string& name) {
45     HloInstruction* to_return = nullptr;
46     for (auto* comp : module->computations()) {
47       for (auto* inst : comp->instructions()) {
48         if (inst->name() == name) {
49           to_return = inst;
50           break;
51         }
52       }
53     }
54     return CHECK_NOTNULL(to_return);
55   }
56 
57   std::unique_ptr<HloLivenessAnalysis> liveness_;
58 };
59 
60 // Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
  // The entry root is an add of two scalar constants.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    ROOT add = s32[] add(constant.1, constant.2)
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // The root and both of its operands are expected to be live.
  EXPECT_TRUE(live_at("add", {}));
  EXPECT_TRUE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
}
75 
76 // Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest, DeadAdd) {
  // add.1 has the same operands as the root add.2 but is not used by it.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    add.1 = s32[] add(constant.1, constant.2)
    ROOT add.2 = s32[] add(constant.1, constant.2)
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // The root and the constants feeding it are live.
  EXPECT_TRUE(live_at("add.2", {}));
  EXPECT_TRUE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
  // The unused add is expected to be dead.
  EXPECT_FALSE(live_at("add.1", {}));
}
93 
94 // Test that all output shape indices of entry root tuple (and defining
95 // instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
  // The entry root is a two-element tuple of scalar constants.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Every shape index of the root tuple is live.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_TRUE(live_at("tuple.1", {1}));
  // So are the constants that define its elements.
  EXPECT_TRUE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
}
112 
113 // Tests that all outputs of nested tuple and entry root (and defining
114 // instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
  // The entry root tuple.2 holds constant.1 and the inner tuple.1.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(1)
    constant.2 = s32[] constant(2)
    constant.3 = s32[] constant(3)
    tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
    ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Inner tuple: all indices live.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_TRUE(live_at("tuple.1", {1}));
  // Root tuple: all indices live, including the nested ones.
  EXPECT_TRUE(live_at("tuple.2", {}));
  EXPECT_TRUE(live_at("tuple.2", {0}));
  EXPECT_TRUE(live_at("tuple.2", {1}));
  EXPECT_TRUE(live_at("tuple.2", {1, 0}));
  EXPECT_TRUE(live_at("tuple.2", {1, 1}));
  // All three defining constants are live.
  EXPECT_TRUE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
  EXPECT_TRUE(live_at("constant.3", {}));
}
139 
// Tests that GTE at entry root of Tuple instruction only propagates liveness
// to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
  // The entry root extracts only element 0 of tuple.1.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
    ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  EXPECT_TRUE(live_at("get-tuple-element.1", {}));
  // Only the extracted element of the tuple is live.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_FALSE(live_at("tuple.1", {1}));
  // Likewise only the constant feeding element 0 is live.
  EXPECT_TRUE(live_at("constant.1", {}));
  EXPECT_FALSE(live_at("constant.2", {}));
}
161 
// Tests that GTE at entry root of nested Tuple instruction only propagates
// liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
  // The entry root extracts element 1 (the inner tuple.1) of tuple.2.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    constant.3 = s32[] constant(2)
    tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
    tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
    ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Root GTE: the extracted tuple is live at every index.
  EXPECT_TRUE(live_at("get-tuple-element.1", {}));
  EXPECT_TRUE(live_at("get-tuple-element.1", {0}));
  EXPECT_TRUE(live_at("get-tuple-element.1", {1}));

  // Outer tuple: element 0 is never extracted, element 1 is fully live.
  EXPECT_TRUE(live_at("tuple.2", {}));
  EXPECT_FALSE(live_at("tuple.2", {0}));
  EXPECT_TRUE(live_at("tuple.2", {1}));
  EXPECT_TRUE(live_at("tuple.2", {1, 0}));
  EXPECT_TRUE(live_at("tuple.2", {1, 1}));

  // Inner tuple: fully live.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_TRUE(live_at("tuple.1", {1}));

  // Only the constants reaching the root through tuple.1 are live.
  EXPECT_FALSE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
  EXPECT_TRUE(live_at("constant.3", {}));
}
198 
// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
// propagates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
  // The entry root extracts element 0 of element 1 of tuple.2.
  const char* const hlo_text = R"(
  HloModule SimpleModule
  ENTRY SimpleComputation {
    constant.1 = s32[] constant(0)
    constant.2 = s32[] constant(1)
    constant.3 = s32[] constant(2)
    tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
    tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
    get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
    ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Root GTE.
  EXPECT_TRUE(live_at("get-tuple-element.2", {}));

  // Intermediate GTE: only the element read by the root is live.
  EXPECT_TRUE(live_at("get-tuple-element.1", {}));
  EXPECT_TRUE(live_at("get-tuple-element.1", {0}));
  EXPECT_FALSE(live_at("get-tuple-element.1", {1}));

  // Outer tuple: only the path {1, 0} is live.
  EXPECT_TRUE(live_at("tuple.2", {}));
  EXPECT_FALSE(live_at("tuple.2", {0}));
  EXPECT_TRUE(live_at("tuple.2", {1}));
  EXPECT_TRUE(live_at("tuple.2", {1, 0}));
  EXPECT_FALSE(live_at("tuple.2", {1, 1}));

  // Inner tuple: only element 0 is live.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_FALSE(live_at("tuple.1", {1}));

  // Only constant.2 reaches the root.
  EXPECT_FALSE(live_at("constant.1", {}));
  EXPECT_TRUE(live_at("constant.2", {}));
  EXPECT_FALSE(live_at("constant.3", {}));
}
240 
241 // Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
  // The loop carries two elements; the condition and the entry root read only
  // element 0.
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add.0 = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Entry root.
  EXPECT_TRUE(live_at("get-tuple-element.4", {}));

  // While result: element 1 is dead.
  EXPECT_TRUE(live_at("while.0", {}));
  EXPECT_TRUE(live_at("while.0", {0}));
  EXPECT_FALSE(live_at("while.0", {1}));

  // While operand.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_FALSE(live_at("tuple.1", {1}));
  EXPECT_TRUE(live_at("constant.3", {}));

  // While body.
  EXPECT_TRUE(live_at("tuple.0", {}));
  EXPECT_TRUE(live_at("tuple.0", {0}));
  EXPECT_FALSE(live_at("tuple.0", {1}));
  EXPECT_TRUE(live_at("add.0", {}));
  EXPECT_FALSE(live_at("multiply.0", {}));
}
290 
291 // Tests that a tuple element live in while.cond computation, propagates
292 // liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
  // Unlike the entry root, the loop condition reads both tuple elements.
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add.0 = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
    add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Entry root.
  EXPECT_TRUE(live_at("get-tuple-element.5", {}));

  // While result: both elements live.
  EXPECT_TRUE(live_at("while.0", {}));
  EXPECT_TRUE(live_at("while.0", {0}));
  EXPECT_TRUE(live_at("while.0", {1}));

  // While operand.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_TRUE(live_at("tuple.1", {1}));
  EXPECT_TRUE(live_at("constant.3", {}));
  EXPECT_TRUE(live_at("constant.4", {}));

  // While body.
  EXPECT_TRUE(live_at("tuple.0", {}));
  EXPECT_TRUE(live_at("tuple.0", {0}));
  EXPECT_TRUE(live_at("tuple.0", {1}));
  EXPECT_TRUE(live_at("add.0", {}));
  EXPECT_TRUE(live_at("multiply.0", {}));
}
344 
345 // Tests that a use of while.result{0} propagates liveness to
346 // while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
  // In the body, root element 0 reads param elements 0 and 1, and root
  // element 1 forwards param element 2.
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[], s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
    add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2
    multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3)
    ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[], s32[]) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
    constant.1 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT
  }
  ENTRY SimpleLoop {
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    constant.4 = s32[] constant(2)
    tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4)
    while.1 = (s32[], s32[], s32[]) while(tuple.2), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Entry root.
  EXPECT_TRUE(live_at("get-tuple-element.5", {}));

  // While result.
  EXPECT_TRUE(live_at("while.1", {}));
  EXPECT_TRUE(live_at("while.1", {0}));
  EXPECT_TRUE(live_at("while.1", {1}));
  EXPECT_TRUE(live_at("while.1", {2}));
  // While operand.
  EXPECT_TRUE(live_at("tuple.2", {}));
  EXPECT_TRUE(live_at("tuple.2", {0}));
  EXPECT_TRUE(live_at("tuple.2", {1}));
  EXPECT_TRUE(live_at("tuple.2", {2}));
  // While body root.
  EXPECT_TRUE(live_at("tuple.1", {}));
  EXPECT_TRUE(live_at("tuple.1", {0}));
  EXPECT_TRUE(live_at("tuple.1", {1}));
  EXPECT_TRUE(live_at("tuple.1", {2}));
  // While body param.
  EXPECT_TRUE(live_at("loop_var.1", {}));
  EXPECT_TRUE(live_at("loop_var.1", {0}));
  EXPECT_TRUE(live_at("loop_var.1", {1}));
  EXPECT_TRUE(live_at("loop_var.1", {2}));
}
400 
TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
  // The entry root is an empty tuple, so the while result is unused, but the
  // loop body contains an outfeed.
  const char* const hlo_text = R"(
  HloModule OutfeedLoop
  WhileBody {
    body_param = (s32[]) parameter(0)
    token0 = token[] after-all()
    constant.2 = s32[] constant(2)
    outfeed_tuple = (s32[]) outfeed(constant.2, token0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[]) tuple(add)
  }
  WhileCondition {
    cond_param = (s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[]) tuple(constant.3)
    while = (s32[]) while(tuple.1), condition=WhileCondition,
      body=WhileBody
    ROOT rtuple = () tuple()
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // The loop counter update in the body and the initial counter value are
  // expected to be live even though the root does not use the loop
  // (presumably because the outfeed keeps the loop alive).
  EXPECT_TRUE(live_at("add", {}));
  EXPECT_TRUE(live_at("constant.3", {}));
}
433 
TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
  // The outfeed sits in the body of an inner loop nested in the outer loop's
  // body; the entry root is an empty tuple that uses neither loop.
  const char* const hlo_text = R"(
  HloModule OutfeedLoop
  InnerWhileBody {
    body_param = (s32[]) parameter(0)
    token0 = token[] after-all()
    constant.2 = s32[] constant(2)
    outfeed_tuple = (s32[]) outfeed(constant.2, token0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[]) tuple(add)
  }
  InnerWhileCondition {
    cond_param = (s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  OuterWhileCondition {
    cond_param.2 = (s32[]) parameter(0)
    get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
    constant.5 = s32[] constant(5)
    ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT
  }
  OuterWhileBody {
    body_param.2 = (s32[]) parameter(0)
    get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
    constant.6 = s32[] constant(0)
    tuple.2 = (s32[]) tuple(constant.6)
    inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
      body=InnerWhileBody
    constant.7 = s32[] constant(1)
    add.2 = s32[] add(get-tuple-element.8, constant.7)
    ROOT rtuple = (s32[]) tuple(add.2)
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[]) tuple(constant.3)
    while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
      body=OuterWhileBody
    ROOT rtuple = () tuple()
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  const HloLivenessAnalysis& liveness = RunLiveness(module.get());

  // Looks up the instruction called `name` and queries its liveness at
  // `index`.
  auto live_at = [&](const char* name, const ShapeIndex& index) {
    return liveness.IsLive(GetInstruction(module.get(), name), index);
  };

  // Counter updates of the inner and outer loops, plus the outer loop's
  // initial counter value, are all expected to be live.
  EXPECT_TRUE(live_at("add", {}));
  EXPECT_TRUE(live_at("add.2", {}));
  EXPECT_TRUE(live_at("constant.3", {}));
}
484 
485 }  // namespace
486 }  // namespace xla
487