1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_replace.h"
20 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
21 #include "tensorflow/compiler/xla/service/hlo_cse.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30
31 namespace xla {
32 namespace {
33
34 using ::testing::_;
35 namespace op = xla::testing::opcode_matchers;
36
37 // Returns the first kWhile instruction within m's entry computation.
FindFirstWhile(HloModule * m)38 HloInstruction* FindFirstWhile(HloModule* m) {
39 const auto& instrs = m->entry_computation()->instructions();
40 return *absl::c_find_if(instrs, [](const HloInstruction* instr) {
41 return instr->opcode() == HloOpcode::kWhile;
42 });
43 }
44
45 class WhileLoopSimplifierTest : public HloTestBase {
46 protected:
47 // Makes an HloModule that contains a loop with `num_iters` iteration.
48 TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>
49 MakeModuleWithSimpleLoop(int num_iters);
50
51 // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
52 // the loop-condition through an element of a tuple which is the
53 // loop-condition parameter.
54 TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>
55 MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
56 };
57
58 std::unique_ptr<VerifiedHloModule>
MakeModuleWithSimpleLoop(int num_iters)59 WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
60 string hlo_string_template = R"(
61 HloModule SimpleLoop
62 SimpleLoop.body {
63 loop_var.1 = (s32[], s32[3]{0}) parameter(0)
64 get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
65 constant.1 = s32[] constant(1)
66 add = s32[] add(get-tuple-element.1, constant.1)
67 get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
68 multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
69 ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
70 }
71 SimpleLoop.condition {
72 loop_var.2 = (s32[], s32[3]{0}) parameter(0)
73 get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
74 constant.2 = s32[] constant({{LOOP_BOUND}})
75 ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
76 }
77 ENTRY SimpleLoop {
78 constant.3 = s32[] constant(42)
79 constant.4 = s32[3]{0} constant({0, 1, 2})
80 tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
81 ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
82 SimpleLoop.condition, body=SimpleLoop.body
83 }
84 )";
85
86 string hlo_string = absl::StrReplaceAll(
87 hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
88 return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
89 }
90
91 std::unique_ptr<VerifiedHloModule>
MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters)92 WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
93 int num_iters) {
94 string hlo_string_template = R"(
95 HloModule SimpleLoopWithIndirectLoopBound
96 SimpleLoopWithIndirectLoopBound.body {
97 loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
98 get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
99 constant.1 = s32[] constant(1)
100 add = s32[] add(get-tuple-element.1, constant.1)
101 get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
102 multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
103 limit = s32[] get-tuple-element(loop_var.1), index=2
104 ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
105 }
106 SimpleLoopWithIndirectLoopBound.condition {
107 loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
108 get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
109 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
110 ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT
111 }
112 ENTRY SimpleLoopWithIndirectLoopBound {
113 constant.3 = s32[] constant(42)
114 constant.4 = s32[3]{0} constant({0, 1, 2})
115 constant.2 = s32[] constant({{LOOP_BOUND}})
116 tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
117 constant.2)
118 ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
119 condition=SimpleLoopWithIndirectLoopBound.condition,
120 body=SimpleLoopWithIndirectLoopBound.body
121 }
122 )";
123
124 string hlo_string = absl::StrReplaceAll(
125 hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
126 return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
127 }
128
// A loop that runs zero times: the pass must report a change and the entry
// root must become a tuple of the two init constants.
TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/0);
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Constant(), op::Constant()));
}
135
// Zero-iteration loop whose bound is carried in the loop tuple: the pass must
// report a change and the entry root must become a tuple of three constants.
TEST_F(WhileLoopSimplifierTest,
       LoopWithZeroIterationTupleElementLoopBoundSimplified) {
  auto module = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Constant(), op::Constant(), op::Constant()));
}
143
// A loop that runs exactly once: the pass must report a change and the entry
// root must become a tuple of the body's add and multiply.
TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(), op::Multiply()));
}
150
// Single-iteration loop whose bound is carried in the loop tuple: the pass
// must report a change and the entry root must become (add, multiply,
// constant).
TEST_F(WhileLoopSimplifierTest,
       LoopWithOneIterationTupleELementLoopBoundSimplified) {
  auto module = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(), op::Multiply(), op::Constant()));
}
158
// A loop with a trip count of two must be left untouched by the pass.
TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/2);
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
163
// Adds a control dependency from a fresh constant to the body's root, then
// checks that after the single-iteration loop is simplified the new entry root
// still has exactly one control predecessor, a constant.
TEST_F(WhileLoopSimplifierTest,
       LoopWithControlDependencySimplifiedDependencyPreserved) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  HloComputation* entry = module->entry_computation();
  HloInstruction* loop = entry->root_instruction();
  ASSERT_EQ(loop->opcode(), HloOpcode::kWhile);

  HloInstruction* pred_constant = loop->while_body()->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* body_root = loop->while_body()->root_instruction();
  TF_ASSERT_OK(pred_constant->AddControlDependencyTo(body_root));

  ASSERT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction()->control_predecessors(),
              ::testing::ElementsAre(op::Constant()))
      << entry->ToString();
}
179
180 // Loops that contain send/recv nodes can't be simplified; the loop structure
181 // around send/recv nodes must be preserved.
TEST_F(WhileLoopSimplifierTest,LoopWithSendNotSimplified)182 TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
183 auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
184 HloComputation* computation = m->entry_computation();
185 auto* while_op = computation->root_instruction();
186 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
187 auto* while_body = while_op->while_body();
188 auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
189 auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
190 while_body->AddInstruction(
191 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
192 token,
193 /*channel_id=*/0));
194 while_body->AddInstruction(HloInstruction::CreateSendDone(send));
195 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
196 }
197
// A single-iteration loop whose body contains a recv/recv-done pair must not
// be simplified.
TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  HloInstruction* loop = module->entry_computation()->root_instruction();
  ASSERT_EQ(loop->opcode(), HloOpcode::kWhile);

  // Insert token -> recv -> recv-done into the loop body.
  HloComputation* body = loop->while_body();
  HloInstruction* token = body->AddInstruction(HloInstruction::CreateToken());
  const Shape recv_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* recv = body->AddInstruction(
      HloInstruction::CreateRecv(recv_shape, token, /*channel_id=*/0));
  body->AddInstruction(HloInstruction::CreateRecvDone(recv));

  EXPECT_FALSE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
}
211
212 // The limitation on not being able to simplify loops that contain infeeds (and
213 // other non-removable instructions) isn't fundamental -- it just stems from the
214 // fact that our infrastructure sees simplifying such a loop as tantamount to
215 // removing the non-removable instruction.
TEST_F(WhileLoopSimplifierTest,LoopWithInfeedNotSimplified)216 TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
217 auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
218 HloComputation* computation = m->entry_computation();
219 auto* while_op = computation->root_instruction();
220 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
221 auto* while_body = while_op->while_body();
222 auto token = while_body->AddInstruction(HloInstruction::CreateToken());
223 while_body->AddInstruction(HloInstruction::CreateInfeed(
224 ShapeUtil::MakeShape(F32, {1}), token, "config"));
225 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
226 }
227
228 // A non-tuple shaped loop shouldn't be simplified or crash the compiler.
TEST_F(WhileLoopSimplifierTest,NonTupleShapedLoopNotSimplified)229 TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) {
230 const string hlo_string = R"(
231 HloModule NonTupleShapedLoop
232 NonTupleShapedLoop.body {
233 loop_var.1 = s32[] parameter(0)
234 constant.1 = s32[] constant(-1)
235 ROOT add = s32[] add(s32[] loop_var.1, s32[] constant.1)
236 }
237 NonTupleShapedLoop.condition {
238 loop_var = s32[] parameter(0)
239 constant = s32[] constant(100)
240 ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT
241 }
242 ENTRY INonTupleShapedLoop {
243 constant.2 = s32[] constant(42)
244 ROOT while = s32[] while(s32[] constant.2),
245 condition=NonTupleShapedLoop.condition,
246 body=NonTupleShapedLoop.body
247 }
248 )";
249
250 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
251 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
252 }
253
254 // A while loop that does nothing else besides swapping tuple elements
255 // can't be simplified as the result of the swapping is visible to users of the
256 // loop.
TEST_F(WhileLoopSimplifierTest,LoopSwappingTupleElementsNotSimplified)257 TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) {
258 const string hlo_string = R"(
259 HloModule SwappingTupleElements
260 SwappingTupleElements.body {
261 loop_var = (s32[], s32[]) parameter(0)
262 get-tuple-element = s32[] get-tuple-element((s32[], s32[]) loop_var),index=1
263 get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[]) loop_var),
264 index=0
265 ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element,
266 s32[] get-tuple-element.1)
267 }
268 SwappingTupleElements.always_true {
269 param = (s32[], s32[]) parameter(0)
270 ROOT constant = pred[] constant(true)
271 }
272 ENTRY SwappingTupleElements {
273 x = s32[] parameter(0)
274 y = s32[] parameter(1)
275 tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y)
276 ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1),
277 condition=SwappingTupleElements.always_true,
278 body=SwappingTupleElements.body
279 }
280 )";
281
282 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
283 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
284 }
285
286 // Construct a loop where we assign a constant to tuple element 0 in each
287 // iteration. We can't eliminate tuple element 0, even though we never use its
288 // value.
TEST_F(WhileLoopSimplifierTest,LoopWithUnusedButModifiedTupleElementNotSimplified)289 TEST_F(WhileLoopSimplifierTest,
290 LoopWithUnusedButModifiedTupleElementNotSimplified) {
291 const string hlo_string = R"(
292 HloModule UnusedButModifiedTupleElement
293 UnusedButModifiedTupleElement.body {
294 loop_var = (s32[]) parameter(0)
295 constant.1 = s32[] constant(1)
296 ROOT tuple = (s32[]) tuple(s32[] constant.1)
297 }
298 UnusedButModifiedTupleElement.always_true {
299 param = (s32[]) parameter(0)
300 ROOT constant = pred[] constant(true)
301 }
302 ENTRY UnusedButModifiedTupleElement {
303 constant.2 = s32[] constant(0)
304 tuple.1 = (s32[]) tuple(s32[] constant.2)
305 ROOT while = (s32[]) while((s32[]) tuple.1),
306 condition=UnusedButModifiedTupleElement.always_true,
307 body=UnusedButModifiedTupleElement.body
308 }
309 )";
310
311 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
312 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
313 }
314
315 // Nothing to simplify in a while loop whose tuple has 0 elements.
TEST_F(WhileLoopSimplifierTest,LoopWithEmptyTupleNotSimplified)316 TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) {
317 const string hlo_string = R"(
318 HloModule EmptyTuple
319 EmptyTuple.body {
320 loop_var = () parameter(0)
321 ROOT tuple = () tuple()
322 }
323 EmptyTuple.always_true {
324 param = () parameter(0)
325 ROOT constant = pred[] constant(true)
326 }
327 ENTRY EmptyTuple {
328 tuple.1 = () tuple()
329 ROOT while = () while(() tuple.1), condition=EmptyTuple.always_true,
330 body=EmptyTuple.body
331 }
332 )";
333
334 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
335 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
336 }
337
338 // While loop where one tuple element is used twice in the body, and thus can't
339 // be simplified away.
TEST_F(WhileLoopSimplifierTest,LoopWithElemUsedTwiceNotSimplified)340 TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) {
341 const string hlo_string = R"(
342 HloModule ElemUsedTwice
343 ElemUsedTwice.body {
344 param0 = (s32[], s32[]) parameter(0)
345 get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param0), index=0
346 ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element,
347 s32[] get-tuple-element)
348 }
349 ElemUsedTwice.always_true {
350 param = (s32[], s32[]) parameter(0)
351 ROOT constant = pred[] constant(true)
352 }
353 ENTRY ElemUsedTwice {
354 x = s32[] parameter(0)
355 y = s32[] parameter(1)
356 tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y)
357 ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1),
358 condition=ElemUsedTwice.always_true, body=ElemUsedTwice.body
359 }
360 )";
361
362 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
363 EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
364 }
365
366 // This while loop has three tuple elements. Element 0 is unused and should be
367 // removed. Element 1 is used by the loop body, and element 2 is used by the
368 // loop condition; these two should stay.
TEST_F(WhileLoopSimplifierTest,RemoveUnusedLoopOperands)369 TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
370 const string hlo_string = R"(
371 HloModule RemoveUnusedOperands
372 RemoveUnusedOperands.body {
373 loop_var = (s32[], s32[], s32[]) parameter(0)
374 get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[],
375 s32[]) loop_var), index=0
376 get-tuple-element.2 = s32[] get-tuple-element((s32[], s32[],
377 s32[]) loop_var), index=1
378 constant.1 = s32[] constant(1)
379 add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1)
380 get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[])
381 loop_var), index=2
382 ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element.1,
383 s32[] add, s32[] get-tuple-element.3)
384 }
385 RemoveUnusedOperands.loop_condition {
386 constant.2 = s32[] constant(0)
387 param0 = (s32[], s32[], s32[]) parameter(0)
388 get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
389 index=2
390 ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ
391 }
392 ENTRY RemoveUnusedOperands {
393 x = s32[] parameter(0)
394 constant.3 = s32[] constant(0)
395 y = s32[] parameter(1)
396 tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] constant.3,
397 s32[] y)
398 ROOT while = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) tuple.1),
399 condition=RemoveUnusedOperands.loop_condition,
400 body=RemoveUnusedOperands.body
401 }
402 )";
403
404 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
405 EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
406
407 // The original while instruction is still left in the module as a dead
408 // instruction, find a while instruction with a different name as the new
409 // while instruction.
410 const auto& instrs = m->entry_computation()->instructions();
411 HloInstruction* new_while_op =
412 *absl::c_find_if(instrs, [&](const HloInstruction* instr) {
413 return (instr->opcode() == HloOpcode::kWhile &&
414 instr->name() != "while");
415 });
416
417 auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
418 EXPECT_TRUE(
419 ShapeUtil::Equal(new_while_op->shape(),
420 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32})))
421 << ShapeUtil::HumanString(new_while_op->shape());
422 EXPECT_THAT(
423 new_while_op->while_body()->root_instruction(),
424 op::Tuple(
425 op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0),
426 op::Constant()),
427 op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
428
429 EXPECT_THAT(new_while_op->while_condition()->root_instruction(),
430 op::Eq(op::Constant(),
431 op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
432 }
433
// The loop body's root is the parameter itself rather than a tuple
// instruction; the pass must report no change for such a loop.
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
  const string module_text = R"(
HloModule BodyHasNonTupleRoot
BodyHasNonTupleRoot.passthrough {
ROOT param = (s32[], s32[]) parameter(0)
}
BodyHasNonTupleRoot.always_true {
param.1 = (s32[], s32[]) parameter(0)
ROOT constant = pred[] constant(true)
}
ENTRY BodyHasNonTupleRoot {
init_value = (s32[], s32[]) parameter(0)
ROOT while = (s32[], s32[]) while((s32[], s32[]) init_value),
condition=BodyHasNonTupleRoot.always_true,
body=BodyHasNonTupleRoot.passthrough
}
)";

  auto module = ParseAndReturnVerifiedModule(module_text).ValueOrDie();
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
455
// The loop body's root is a tuple-shaped custom-call rather than a tuple
// instruction; the pass must report no change for such a loop.
TEST_F(WhileLoopSimplifierTest,
       LoopWithNonTupleBodyRootInstructionNotSimplified) {
  const string module_text = R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
custom_call_target="x"
}
SimpleLoop.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(44)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
SimpleLoop.condition, body=SimpleLoop.body
}
)";

  auto module = ParseAndReturnVerifiedModule(module_text).ValueOrDie();
  const bool changed = WhileLoopSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
488
// A loop over (s32[], s32[3], s32[3]) whose element 2 is initialized from an
// array constant and passed through the body unchanged. The pass is expected
// to report no change for this module.
//
// The declared shapes of `tuple` and `tuple.1` list all three elements so that
// they agree with their three operands and with the parameter/while shapes.
TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
  const string hlo_string = R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
get-tuple-element.3 = s32[3]{0} get-tuple-element(loop_var.1), index=2
add.2 = s32[3]{0} add(get-tuple-element.2, get-tuple-element.3)
ROOT tuple = (s32[], s32[3]{0}, s32[3]{0}) tuple(add, add.2, get-tuple-element.3)
}
SimpleLoop.condition {
loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(47)
ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}, s32[3]{0}) tuple(constant.3, constant.4, constant.4)
ROOT while = (s32[], s32[3]{0}, s32[3]{0}) while(tuple.1), condition=
SimpleLoop.condition, body=SimpleLoop.body
}
)";

  auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
}
520
// A loop carrying a nested tuple: after the pass, the replacement loop (and
// its body root, body parameter and condition parameter) must use the
// flattened four-element tuple shape, while the entry root keeps the original
// nested shape.
TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) {
  const string module_text = R"(
HloModule Test
Body {
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
ta = (s32[1]) get-tuple-element(param), index=0
a = s32[1] get-tuple-element(ta), index=0
a.1 = s32[1] add(a, a)
tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
}
Cond {
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
ROOT cond = pred[] constant(true)
}
ENTRY Loop {
a = s32[1] constant({0})
b = s32[2] constant({0,1})
c = s32[3] constant({0,1,2})
d = s32[4] constant({0,1,2,3})
ta = (s32[1]) tuple(a)
td = (s32[4]) tuple(d)
tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
condition=Cond, body=Body
})";

  auto module = ParseAndReturnVerifiedModule(module_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // Remove the dead original loop via DCE; with only one while loop left in
  // the module, locating the replacement is trivial.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());

  HloInstruction* flattened_while = FindFirstWhile(module.get());
  Shape flat_tuple =
      ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie();
  SCOPED_TRACE(module->ToString());

  auto* body = flattened_while->while_body();
  auto* cond = flattened_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(flattened_while->shape(), flat_tuple));
  EXPECT_TRUE(ShapeUtil::Equal(body->root_instruction()->shape(), flat_tuple));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), flat_tuple));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), flat_tuple));

  // The value visible to users of the loop keeps its nested shape.
  EXPECT_TRUE(ShapeUtil::Equal(
      module->entry_computation()->root_instruction()->shape(),
      ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie()));
}
571
572 // Edge-case: All elements of the loop carry are constants which can be removed,
573 // leaving us with a nullary loop. This is a special case, we just replace the
574 // loop with its init.
TEST_F(WhileLoopSimplifierTest,OnlyConstantsInLoopCarry)575 TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) {
576 const string hlo_string = R"(
577 HloModule Test
578 Body {
579 param = (s32[1]) parameter(0)
580 a = s32[1] constant({0})
581 ROOT tuple = (s32[1]) tuple(a)
582 }
583 Cond {
584 param = (s32[1]) parameter(0)
585 ROOT cond = pred[] constant(true)
586 }
587 ENTRY Loop {
588 a = s32[1] constant({0})
589 init = (s32[1]) tuple(a)
590 ROOT while = (s32[1]) while(init), condition=Cond, body=Body
591 })";
592
593 auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
594 EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
595 EXPECT_TRUE(HloDCE().Run(m.get()).ok());
596 EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok());
597 EXPECT_THAT(m->entry_computation()->root_instruction(),
598 op::Tuple(op::Constant()));
599 }
600
// Of the three carried elements only `b` (index 1) is expected to be removed
// from the loop: the replacement loop carries (s32[1], s32[3]) while the entry
// root keeps the original three-element shape with a constant at index 1.
TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) {
  const string module_text = R"(
HloModule Test
Body {
param = (s32[1], s32[2], s32[3]) parameter(0)
a = s32[1] get-tuple-element(param), index=0
a.1 = s32[1] add(a, a)
b = s32[2] constant({1,1})
c = s32[3] constant({10,10,10})
ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c)
}
Cond {
param = (s32[1], s32[2], s32[3]) parameter(0)
/* Use each tuple element. The verifier will then ensure that if any of
* these get modified, they're replaced with values of the correct shape. */
a = s32[1] get-tuple-element(param), index=0
b = s32[2] get-tuple-element(param), index=1
c = s32[3] get-tuple-element(param), index=2
ROOT cond = pred[] constant(true)
}
ENTRY Loop {
/* Only `b` should be simplified away. `a` is not a constant within the
* loop, and `c`'s value changes depending on whether we run 0 or 1
* iterations of the loop. */
a = s32[1] constant({0})
b = s32[2] constant({1,1})
c = s32[3] constant({2,2,2})
init = (s32[1], s32[2], s32[3]) tuple(a,b,c)
ROOT while = (s32[1], s32[2], s32[3]) while(init),
condition=Cond, body=Body
})";

  auto module = ParseAndReturnVerifiedModule(module_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // Remove the dead original loop via DCE; with only one while loop left in
  // the module, locating the replacement is trivial.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());
  // The tuple simplifier makes the resulting HLO a little easier to check.
  EXPECT_TRUE(TupleSimplifier().Run(module.get()).ok());

  HloInstruction* reduced_while = FindFirstWhile(module.get());
  Shape reduced_shape = ParseShape("(s32[1], s32[3])").ValueOrDie();

  auto* body = reduced_while->while_body();
  auto* cond = reduced_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(reduced_while->shape(), reduced_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->root_instruction()->shape(), reduced_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), reduced_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), reduced_shape));

  // Users of the loop still see all three elements; the middle one is now a
  // plain constant.
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(
      ShapeUtil::Equal(entry_root->shape(),
                       ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie()));
  EXPECT_THAT(entry_root, op::Tuple(_, op::Constant(), _));
}
658
// HLO text template shared by the MergeInductionVariables_* tests. Every
// occurrence of the placeholder `TYPE` must be replaced with a concrete
// element type (the tests use absl::StrReplaceAll) before parsing.
//
// The loop carries three scalars (a, b, c). Each iteration the body adds 1 to
// `a`, adds -1 to `b`, and sets `c` to a + b. The condition compares
// power(a, b) against the constant 10 with direction LT. The entry
// computation returns the sum of elements 0 and 1 of the loop result.
const char* const kSimpleMergeInductionVariablesModule = R"(
HloModule Test
Body {
param = (TYPE[], TYPE[], TYPE[]) parameter(0)

a = TYPE[] get-tuple-element(param), index=0
one = TYPE[] constant(1)
a1 = TYPE[] add(a, one)

b = TYPE[] get-tuple-element(param), index=1
negone = TYPE[] constant(-1)
b1 = TYPE[] add(b, negone)

c = TYPE[] add(a, b)

ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c)
}
Cond {
param = (TYPE[], TYPE[], TYPE[]) parameter(0)
a = TYPE[] get-tuple-element(param), index=0
b = TYPE[] get-tuple-element(param), index=1
sum = TYPE[] power(a, b)
ten = TYPE[] constant(10)
ROOT cond = pred[] compare(sum, ten), direction=LT
}
ENTRY Loop {
a = TYPE[] constant(10)
b = TYPE[] constant(100)
c = TYPE[] constant(0)
init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c)
while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body

a1 = TYPE[] get-tuple-element(while), index=0
b1 = TYPE[] get-tuple-element(while), index=1
ROOT sum = TYPE[] add(a1, b1)
})";
695
// Instantiates the shared induction-variable module with s32 and checks that
// the pass rewrites the loop to carry one extra s32[] element, with the
// expected structure of the new body root and condition root.
TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) {
  const string module_text = absl::StrReplaceAll(
      kSimpleMergeInductionVariablesModule, {{"TYPE", "s32"}});

  auto module = ParseAndReturnVerifiedModule(module_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // Remove the dead original loop via DCE so that only one while loop remains
  // and is easy to find, then run the tuple simplifier so the resulting HLO is
  // easier to check.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());
  EXPECT_TRUE(TupleSimplifier().Run(module.get()).ok());

  HloInstruction* merged_while = FindFirstWhile(module.get());
  // A new s32[] loop counter should have been appended to the end of the
  // tuple.
  SCOPED_TRACE(module->ToString());
  Shape merged_shape =
      ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie();

  auto* body = merged_while->while_body();
  auto* cond = merged_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(merged_while->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->root_instruction()->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), merged_shape));

  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(op::Parameter(), 0),
                        op::GetTupleElement(op::Parameter(), 1), op::Add(),
                        op::Add(op::GetTupleElement(op::Parameter(), 3),
                                op::Constant())));
  EXPECT_THAT(cond->root_instruction(),
              op::Lt(op::Power(op::Add(), op::Add()), op::Constant()));
}
731
732 // We shouldn't merge S16 induction variables; we can't create constants of this
733 // type because S16 literals are not implemented.
TEST_F(WhileLoopSimplifierTest,MergeInductionVariables_SkipS16)734 TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) {
735 string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule,
736 {{"TYPE", "s16"}});
737 EXPECT_FALSE(
738 WhileLoopSimplifier()
739 .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get())
740 .ValueOrDie());
741 }
742
743 } // namespace
744 } // namespace xla
745