1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace xla {
31 namespace {
32
33 class HloLivenessAnalysisTest : public HloTestBase {
34 protected:
HloLivenessAnalysisTest()35 HloLivenessAnalysisTest() {}
36
37 // Run liveness analysis on the member module. For convenience returns a
38 // reference to the generated analysis stored in analysis_.
RunLiveness(HloModule * module)39 const HloLivenessAnalysis& RunLiveness(HloModule* module) {
40 liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie();
41 return *liveness_;
42 }
43
GetInstruction(HloModule * module,const string & name)44 HloInstruction* GetInstruction(HloModule* module, const string& name) {
45 HloInstruction* to_return = nullptr;
46 for (auto* comp : module->computations()) {
47 for (auto* inst : comp->instructions()) {
48 if (inst->name() == name) {
49 to_return = inst;
50 break;
51 }
52 }
53 }
54 return CHECK_NOTNULL(to_return);
55 }
56
57 std::unique_ptr<HloLivenessAnalysis> liveness_;
58 };
59
60 // Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest,AddAtEntryRoot)61 TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
62 auto module = ParseHloString(R"(
63 HloModule SimpleModule
64 ENTRY SimpleComputation {
65 constant.1 = s32[] constant(0)
66 constant.2 = s32[] constant(1)
67 ROOT add = s32[] add(constant.1, constant.2)
68 })")
69 .ValueOrDie();
70 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
71 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
72 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
73 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
74 }
75
76 // Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest,DeadAdd)77 TEST_F(HloLivenessAnalysisTest, DeadAdd) {
78 auto module = ParseHloString(R"(
79 HloModule SimpleModule
80 ENTRY SimpleComputation {
81 constant.1 = s32[] constant(0)
82 constant.2 = s32[] constant(1)
83 add.1 = s32[] add(constant.1, constant.2)
84 ROOT add.2 = s32[] add(constant.1, constant.2)
85 })")
86 .ValueOrDie();
87 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
88 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
89 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
90 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
91 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {}));
92 }
93
94 // Test that all output shape indices of entry root tuple (and defining
95 // instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest,TupleAtEntryRoot)96 TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
97 auto module = ParseHloString(R"(
98 HloModule SimpleModule
99 ENTRY SimpleComputation {
100 constant.1 = s32[] constant(0)
101 constant.2 = s32[] constant(1)
102 ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
103 })")
104 .ValueOrDie();
105 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
106 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
107 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
108 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
109 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
110 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
111 }
112
113 // Tests that all outputs of nested tuple and entry root (and defining
114 // instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest,NestedTupleAtEntryRoot)115 TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
116 auto module = ParseHloString(R"(
117 HloModule SimpleModule
118 ENTRY SimpleComputation {
119 constant.1 = s32[] constant(1)
120 constant.2 = s32[] constant(2)
121 constant.3 = s32[] constant(3)
122 tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
123 ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
124 })")
125 .ValueOrDie();
126 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
127 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
128 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
129 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
130 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
131 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
132 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
133 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
134 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
135 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
136 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
137 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
138 }
139
140 // Tests that GTE at entry root of Tuple instruction only propgates liveness
141 // to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest,GteOfTuple)142 TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
143 auto module = ParseHloString(R"(
144 HloModule SimpleModule
145 ENTRY SimpleComputation {
146 constant.1 = s32[] constant(0)
147 constant.2 = s32[] constant(1)
148 tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
149 ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0
150 })")
151 .ValueOrDie();
152 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
153 EXPECT_TRUE(
154 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
155 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
156 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
157 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
158 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
159 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
160 }
161
162 // Tests that GTE at entry root of nested Tuple instruction only propgates
163 // liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest,GteOfNestedTuple)164 TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
165 auto module = ParseHloString(R"(
166 HloModule SimpleModule
167 ENTRY SimpleComputation {
168 constant.1 = s32[] constant(0)
169 constant.2 = s32[] constant(1)
170 constant.3 = s32[] constant(2)
171 tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
172 tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
173 ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
174 })")
175 .ValueOrDie();
176 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
177 EXPECT_TRUE(
178 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
179 EXPECT_TRUE(liveness.IsLive(
180 GetInstruction(module.get(), "get-tuple-element.1"), {0}));
181 EXPECT_TRUE(liveness.IsLive(
182 GetInstruction(module.get(), "get-tuple-element.1"), {1}));
183
184 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
185 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
186 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
187 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
188 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
189
190 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
191 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
192 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
193
194 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
195 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
196 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
197 }
198
199 // Tests that GTE of GTE (at entry root) of nested Tuple instruction only
200 // propgates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest,GteOfGteOfNestedTuple)201 TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
202 auto module = ParseHloString(R"(
203 HloModule SimpleModule
204 ENTRY SimpleComputation {
205 constant.1 = s32[] constant(0)
206 constant.2 = s32[] constant(1)
207 constant.3 = s32[] constant(2)
208 tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
209 tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
210 get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
211 ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0
212 })")
213 .ValueOrDie();
214 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
215 EXPECT_TRUE(
216 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {}));
217
218 EXPECT_TRUE(
219 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
220 EXPECT_TRUE(liveness.IsLive(
221 GetInstruction(module.get(), "get-tuple-element.1"), {0}));
222 EXPECT_FALSE(liveness.IsLive(
223 GetInstruction(module.get(), "get-tuple-element.1"), {1}));
224
225 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
226 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
227 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
228 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
229 EXPECT_FALSE(
230 liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
231
232 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
233 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
234 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
235
236 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
237 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
238 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
239 }
240
241 // Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest,WhileWithDeadTupleElement)242 TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
243 auto module = ParseHloString(R"(
244 HloModule SimpleLoop
245 SimpleLoop.body {
246 loop_var.1 = (s32[], s32[3]{0}) parameter(0)
247 get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
248 constant.1 = s32[] constant(1)
249 add.0 = s32[] add(get-tuple-element.1, constant.1)
250 get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
251 multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
252 ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
253 }
254 SimpleLoop.condition {
255 loop_var.2 = (s32[], s32[3]{0}) parameter(0)
256 get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
257 constant.2 = s32[] constant(5)
258 ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
259 }
260 ENTRY SimpleLoop {
261 constant.3 = s32[] constant(0)
262 constant.4 = s32[3]{0} constant({0, 1, 2})
263 tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
264 while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
265 SimpleLoop.condition, body=SimpleLoop.body
266 ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0
267 })")
268 .ValueOrDie();
269 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
270 EXPECT_TRUE(
271 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {}));
272
273 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {}));
274 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0}));
275 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1}));
276
277 // While operand.
278 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
279 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
280 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
281 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
282
283 // While body.
284 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
285 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
286 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
287 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
288 EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
289 }
290
291 // Tests that a tuple element live in while.cond computation, propagates
292 // liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest,WhileCondPropagatesLiveness)293 TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
294 auto module = ParseHloString(R"(
295 HloModule SimpleLoop
296 SimpleLoop.body {
297 loop_var.1 = (s32[], s32[3]{0}) parameter(0)
298 get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
299 constant.1 = s32[] constant(1)
300 add.0 = s32[] add(get-tuple-element.1, constant.1)
301 get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
302 multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
303 ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
304 }
305 SimpleLoop.condition {
306 loop_var.2 = (s32[], s32[3]{0}) parameter(0)
307 get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
308 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
309 add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
310 constant.2 = s32[] constant(5)
311 ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
312 }
313 ENTRY SimpleLoop {
314 constant.3 = s32[] constant(0)
315 constant.4 = s32[3]{0} constant({0, 1, 2})
316 tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
317 while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
318 SimpleLoop.condition, body=SimpleLoop.body
319 ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0
320 })")
321 .ValueOrDie();
322 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
323 EXPECT_TRUE(
324 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {}));
325
326 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {}));
327 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0}));
328 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1}));
329
330 // While operand.
331 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
332 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
333 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
334 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
335 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {}));
336
337 // While body.
338 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
339 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
340 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
341 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
342 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
343 }
344
345 // Tests that a use of while.result{0} propagates liveness to
346 // while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest,WhileWithLiveTupleElements)347 TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
348 auto module = ParseHloString(R"(
349 HloModule SimpleLoop
350 SimpleLoop.body {
351 loop_var.1 = (s32[], s32[], s32[]) parameter(0)
352 get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
353 get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
354 add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
355 get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2
356 multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3)
357 ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1)
358 }
359 SimpleLoop.condition {
360 loop_var.2 = (s32[], s32[], s32[]) parameter(0)
361 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
362 constant.1 = s32[] constant(5)
363 ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT
364 }
365 ENTRY SimpleLoop {
366 constant.2 = s32[] constant(0)
367 constant.3 = s32[] constant(1)
368 constant.4 = s32[] constant(2)
369 tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4)
370 while.1 = (s32[], s32[], s32[]) while(tuple.2), condition=
371 SimpleLoop.condition, body=SimpleLoop.body
372 ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0
373 })")
374 .ValueOrDie();
375
376 const HloLivenessAnalysis& liveness = RunLiveness(module.get());
377 EXPECT_TRUE(
378 liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {}));
379
380 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {}));
381 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0}));
382 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1}));
383 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2}));
384 // While operand.
385 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
386 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
387 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
388 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2}));
389 // While body root.
390 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
391 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
392 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
393 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2}));
394 // While body param.
395 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {}));
396 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0}));
397 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1}));
398 EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
399 }
400
// Tests a while loop whose result is never used by the entry root (the root is
// an empty tuple) but whose body contains an outfeed. The test expects the
// loop to stay live anyway: the body's loop-carried add and the loop's initial
// value (constant.3) must be marked live.
TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
  // outfeed consumes a token and produces a token, so its result is declared
  // as token[].
  auto module = ParseHloString(R"(
  HloModule OutfeedLoop
  WhileBody {
    body_param = (s32[]) parameter(0)
    token0 = token[] after-all()
    constant.2 = s32[] constant(2)
    outfeed_tuple = token[] outfeed(constant.2, token0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[]) tuple(add)
  }
  WhileCondition {
    cond_param = (s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[]) tuple(constant.3)
    while = (s32[]) while(tuple.1), condition=WhileCondition,
      body=WhileBody
    ROOT rtuple = () tuple()
  })")
                    .ValueOrDie();

  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
}
433
// Tests nested while loops where only the inner loop's body contains an
// outfeed, and neither loop's result is used (the entry root and the inner
// loop's result are both unused). The test expects liveness to reach the inner
// body (add), the outer body (add.2) and the outer loop's initial value
// (constant.3).
TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
  // outfeed consumes a token and produces a token, so its result is declared
  // as token[].
  auto module = ParseHloString(R"(
  HloModule OutfeedLoop
  InnerWhileBody {
    body_param = (s32[]) parameter(0)
    token0 = token[] after-all()
    constant.2 = s32[] constant(2)
    outfeed_tuple = token[] outfeed(constant.2, token0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[]) tuple(add)
  }
  InnerWhileCondition {
    cond_param = (s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  OuterWhileCondition {
    cond_param.2 = (s32[]) parameter(0)
    get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
    constant.5 = s32[] constant(5)
    ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT
  }
  OuterWhileBody {
    body_param.2 = (s32[]) parameter(0)
    get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
    constant.6 = s32[] constant(0)
    tuple.2 = (s32[]) tuple(constant.6)
    inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
      body=InnerWhileBody
    constant.7 = s32[] constant(1)
    add.2 = s32[] add(get-tuple-element.8, constant.7)
    ROOT rtuple = (s32[]) tuple(add.2)
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[]) tuple(constant.3)
    while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
      body=OuterWhileBody
    ROOT rtuple = () tuple()
  })")
                    .ValueOrDie();

  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
}
484
485 } // namespace
486 } // namespace xla
487