1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_replace.h"
20 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
21 #include "tensorflow/compiler/xla/service/hlo_cse.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 
31 namespace xla {
32 namespace {
33 
34 using ::testing::_;
35 namespace op = xla::testing::opcode_matchers;
36 
37 // Returns the first kWhile instruction within m's entry computation.
FindFirstWhile(HloModule * m)38 HloInstruction* FindFirstWhile(HloModule* m) {
39   const auto& instrs = m->entry_computation()->instructions();
40   return *absl::c_find_if(instrs, [](const HloInstruction* instr) {
41     return instr->opcode() == HloOpcode::kWhile;
42   });
43 }
44 
45 class WhileLoopSimplifierTest : public HloTestBase {
46  protected:
47   // Makes an HloModule that contains a loop with `num_iters` iteration.
48   TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>
49   MakeModuleWithSimpleLoop(int num_iters);
50 
51   // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
52   // the loop-condition through an element of a tuple which is the
53   // loop-condition parameter.
54   TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>
55   MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
56 };
57 
58 std::unique_ptr<VerifiedHloModule>
MakeModuleWithSimpleLoop(int num_iters)59 WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
60   string hlo_string_template = R"(
61   HloModule SimpleLoop
62   SimpleLoop.body {
63     loop_var.1 = (s32[], s32[3]{0}) parameter(0)
64     get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
65     constant.1 = s32[] constant(1)
66     add = s32[] add(get-tuple-element.1, constant.1)
67     get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
68     multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
69     ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
70   }
71   SimpleLoop.condition {
72     loop_var.2 = (s32[], s32[3]{0}) parameter(0)
73     get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
74     constant.2 = s32[] constant({{LOOP_BOUND}})
75     ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
76   }
77   ENTRY SimpleLoop {
78     constant.3 = s32[] constant(42)
79     constant.4 = s32[3]{0} constant({0, 1, 2})
80     tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
81     ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
82       SimpleLoop.condition, body=SimpleLoop.body
83   }
84   )";
85 
86   string hlo_string = absl::StrReplaceAll(
87       hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
88   return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
89 }
90 
91 std::unique_ptr<VerifiedHloModule>
MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters)92 WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
93     int num_iters) {
94   string hlo_string_template = R"(
95   HloModule SimpleLoopWithIndirectLoopBound
96   SimpleLoopWithIndirectLoopBound.body {
97     loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
98     get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
99     constant.1 = s32[] constant(1)
100     add = s32[] add(get-tuple-element.1, constant.1)
101     get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
102     multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
103     limit = s32[] get-tuple-element(loop_var.1), index=2
104     ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
105   }
106   SimpleLoopWithIndirectLoopBound.condition {
107     loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
108     get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
109     get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
110     ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT
111   }
112   ENTRY SimpleLoopWithIndirectLoopBound {
113     constant.3 = s32[] constant(42)
114     constant.4 = s32[3]{0} constant({0, 1, 2})
115     constant.2 = s32[] constant({{LOOP_BOUND}})
116     tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
117       constant.2)
118     ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
119       condition=SimpleLoopWithIndirectLoopBound.condition,
120       body=SimpleLoopWithIndirectLoopBound.body
121   }
122   )";
123 
124   string hlo_string = absl::StrReplaceAll(
125       hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
126   return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
127 }
128 
// A loop that runs zero times is replaced by its init tuple.
TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/0);
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Constant(), op::Constant()));
}
135 
// Zero-iteration loop whose bound is carried in the loop tuple: the root
// becomes the three-constant init tuple.
TEST_F(WhileLoopSimplifierTest,
       LoopWithZeroIterationTupleElementLoopBoundSimplified) {
  auto module = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Constant(), op::Constant(), op::Constant()));
}
143 
// A loop that runs exactly once is replaced by a tuple of the body's add and
// multiply.
TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(), op::Multiply()));
}
150 
// One-iteration loop whose bound is carried in the loop tuple: the bound
// element stays a constant in the simplified root.
TEST_F(WhileLoopSimplifierTest,
       LoopWithOneIterationTupleELementLoopBoundSimplified) {
  auto module = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(), op::Multiply(), op::Constant()));
}
158 
// The pass is expected to report no change for a two-iteration loop.
TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/2);
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
163 
// A control dependency added to the loop body's root must still be present on
// the entry root after the one-iteration loop is simplified.
TEST_F(WhileLoopSimplifierTest,
       LoopWithControlDependencySimplifiedDependencyPreserved) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  HloComputation* entry = module->entry_computation();
  HloInstruction* loop = entry->root_instruction();
  ASSERT_EQ(loop->opcode(), HloOpcode::kWhile);

  // Add a constant to the body and make the body's root control-depend on it.
  HloInstruction* pred_true = loop->while_body()->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* body_root = loop->while_body()->root_instruction();
  TF_ASSERT_OK(pred_true->AddControlDependencyTo(body_root));

  ASSERT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction()->control_predecessors(),
              ElementsAre(op::Constant()))
      << entry->ToString();
}
179 
180 // Loops that contain send/recv nodes can't be simplified; the loop structure
181 // around send/recv nodes must be preserved.
TEST_F(WhileLoopSimplifierTest,LoopWithSendNotSimplified)182 TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
183   auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
184   HloComputation* computation = m->entry_computation();
185   auto* while_op = computation->root_instruction();
186   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
187   auto* while_body = while_op->while_body();
188   auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
189   auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
190       while_body->AddInstruction(
191           HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
192       token,
193       /*channel_id=*/0));
194   while_body->AddInstruction(HloInstruction::CreateSendDone(send));
195   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
196 }
197 
// Counterpart of the send test: a recv/recv-done pair in the loop body keeps
// the loop from being simplified.
TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
  auto module = MakeModuleWithSimpleLoop(/*num_iters=*/1);
  HloInstruction* loop = module->entry_computation()->root_instruction();
  ASSERT_EQ(loop->opcode(), HloOpcode::kWhile);

  // Insert token -> recv -> recv-done into the loop body.
  HloComputation* body = loop->while_body();
  HloInstruction* token = body->AddInstruction(HloInstruction::CreateToken());
  const Shape recv_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* recv = body->AddInstruction(
      HloInstruction::CreateRecv(recv_shape, token, /*channel_id=*/0));
  body->AddInstruction(HloInstruction::CreateRecvDone(recv));

  EXPECT_FALSE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
}
211 
212 // The limitation on not being able to simplify loops that contain infeeds (and
213 // other non-removable instructions) isn't fundamental -- it just stems from the
214 // fact that our infrastructure sees simplifying such a loop as tantamount to
215 // removing the non-removable instruction.
TEST_F(WhileLoopSimplifierTest,LoopWithInfeedNotSimplified)216 TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
217   auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
218   HloComputation* computation = m->entry_computation();
219   auto* while_op = computation->root_instruction();
220   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
221   auto* while_body = while_op->while_body();
222   auto token = while_body->AddInstruction(HloInstruction::CreateToken());
223   while_body->AddInstruction(HloInstruction::CreateInfeed(
224       ShapeUtil::MakeShape(F32, {1}), token, "config"));
225   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
226 }
227 
228 // A non-tuple shaped loop shouldn't be simplified or crash the compiler.
TEST_F(WhileLoopSimplifierTest,NonTupleShapedLoopNotSimplified)229 TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) {
230   const string hlo_string = R"(
231  HloModule NonTupleShapedLoop
232  NonTupleShapedLoop.body {
233    loop_var.1 = s32[] parameter(0)
234    constant.1 = s32[] constant(-1)
235    ROOT add = s32[] add(s32[] loop_var.1, s32[] constant.1)
236  }
237  NonTupleShapedLoop.condition {
238    loop_var = s32[] parameter(0)
239    constant = s32[] constant(100)
240    ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT
241  }
242  ENTRY INonTupleShapedLoop {
243    constant.2 = s32[] constant(42)
244    ROOT while = s32[] while(s32[] constant.2),
245      condition=NonTupleShapedLoop.condition,
246      body=NonTupleShapedLoop.body
247   }
248   )";
249 
250   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
251   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
252 }
253 
254 // A while loop that does nothing else besides swapping tuple elements
255 // can't be simplified as the result of the swapping is visible to users of the
256 // loop.
TEST_F(WhileLoopSimplifierTest,LoopSwappingTupleElementsNotSimplified)257 TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) {
258   const string hlo_string = R"(
259   HloModule SwappingTupleElements
260   SwappingTupleElements.body {
261     loop_var = (s32[], s32[]) parameter(0)
262     get-tuple-element = s32[] get-tuple-element((s32[], s32[]) loop_var),index=1
263     get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[]) loop_var),
264       index=0
265     ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element,
266       s32[] get-tuple-element.1)
267   }
268   SwappingTupleElements.always_true {
269    param = (s32[], s32[]) parameter(0)
270    ROOT constant = pred[] constant(true)
271   }
272   ENTRY SwappingTupleElements {
273    x = s32[] parameter(0)
274    y = s32[] parameter(1)
275    tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y)
276    ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1),
277      condition=SwappingTupleElements.always_true,
278      body=SwappingTupleElements.body
279   }
280   )";
281 
282   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
283   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
284 }
285 
286 // Construct a loop where we assign a constant to tuple element 0 in each
287 // iteration.  We can't eliminate tuple element 0, even though we never use its
288 // value.
TEST_F(WhileLoopSimplifierTest,LoopWithUnusedButModifiedTupleElementNotSimplified)289 TEST_F(WhileLoopSimplifierTest,
290        LoopWithUnusedButModifiedTupleElementNotSimplified) {
291   const string hlo_string = R"(
292   HloModule UnusedButModifiedTupleElement
293   UnusedButModifiedTupleElement.body {
294     loop_var = (s32[]) parameter(0)
295     constant.1 = s32[] constant(1)
296     ROOT tuple = (s32[]) tuple(s32[] constant.1)
297   }
298   UnusedButModifiedTupleElement.always_true {
299     param = (s32[]) parameter(0)
300    ROOT  constant = pred[] constant(true)
301   }
302   ENTRY  UnusedButModifiedTupleElement {
303     constant.2 = s32[] constant(0)
304     tuple.1 = (s32[]) tuple(s32[]  constant.2)
305     ROOT while = (s32[]) while((s32[]) tuple.1),
306       condition=UnusedButModifiedTupleElement.always_true,
307       body=UnusedButModifiedTupleElement.body
308   }
309   )";
310 
311   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
312   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
313 }
314 
315 // Nothing to simplify in a while loop whose tuple has 0 elements.
TEST_F(WhileLoopSimplifierTest,LoopWithEmptyTupleNotSimplified)316 TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) {
317   const string hlo_string = R"(
318   HloModule EmptyTuple
319   EmptyTuple.body {
320     loop_var = () parameter(0)
321     ROOT  tuple = () tuple()
322   }
323   EmptyTuple.always_true {
324    param = () parameter(0)
325    ROOT constant = pred[] constant(true)
326   }
327   ENTRY EmptyTuple {
328    tuple.1 = () tuple()
329    ROOT while = () while(() tuple.1), condition=EmptyTuple.always_true,
330      body=EmptyTuple.body
331   }
332   )";
333 
334   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
335   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
336 }
337 
338 // While loop where one tuple element is used twice in the body, and thus can't
339 // be simplified away.
TEST_F(WhileLoopSimplifierTest,LoopWithElemUsedTwiceNotSimplified)340 TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) {
341   const string hlo_string = R"(
342   HloModule ElemUsedTwice
343   ElemUsedTwice.body {
344     param0 = (s32[], s32[]) parameter(0)
345     get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param0), index=0
346     ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element,
347       s32[] get-tuple-element)
348   }
349   ElemUsedTwice.always_true {
350     param = (s32[], s32[]) parameter(0)
351     ROOT constant = pred[] constant(true)
352   }
353   ENTRY ElemUsedTwice {
354    x = s32[] parameter(0)
355    y = s32[] parameter(1)
356    tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y)
357    ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1),
358      condition=ElemUsedTwice.always_true, body=ElemUsedTwice.body
359   }
360   )";
361 
362   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
363   EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
364 }
365 
366 // This while loop has three tuple elements.  Element 0 is unused and should be
367 // removed. Element 1 is used by the loop body, and element 2 is used by the
368 // loop condition; these two should stay.
TEST_F(WhileLoopSimplifierTest,RemoveUnusedLoopOperands)369 TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
370   const string hlo_string = R"(
371   HloModule RemoveUnusedOperands
372   RemoveUnusedOperands.body {
373     loop_var = (s32[], s32[], s32[]) parameter(0)
374     get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[],
375       s32[]) loop_var), index=0
376     get-tuple-element.2 = s32[] get-tuple-element((s32[], s32[],
377       s32[]) loop_var), index=1
378     constant.1 = s32[] constant(1)
379     add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1)
380     get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[])
381       loop_var), index=2
382     ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element.1,
383       s32[] add, s32[] get-tuple-element.3)
384   }
385   RemoveUnusedOperands.loop_condition {
386     constant.2 = s32[] constant(0)
387     param0 = (s32[], s32[], s32[]) parameter(0)
388     get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
389       index=2
390     ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ
391   }
392   ENTRY RemoveUnusedOperands {
393     x = s32[] parameter(0)
394     constant.3 = s32[] constant(0)
395     y = s32[] parameter(1)
396     tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] constant.3,
397       s32[] y)
398     ROOT while = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) tuple.1),
399       condition=RemoveUnusedOperands.loop_condition,
400       body=RemoveUnusedOperands.body
401   }
402   )";
403 
404   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
405   EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
406 
407   // The original while instruction is still left in the module as a dead
408   // instruction, find a while instruction with a different name as the new
409   // while instruction.
410   const auto& instrs = m->entry_computation()->instructions();
411   HloInstruction* new_while_op =
412       *absl::c_find_if(instrs, [&](const HloInstruction* instr) {
413         return (instr->opcode() == HloOpcode::kWhile &&
414                 instr->name() != "while");
415       });
416 
417   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
418   EXPECT_TRUE(
419       ShapeUtil::Equal(new_while_op->shape(),
420                        ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32})))
421       << ShapeUtil::HumanString(new_while_op->shape());
422   EXPECT_THAT(
423       new_while_op->while_body()->root_instruction(),
424       op::Tuple(
425           op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0),
426                   op::Constant()),
427           op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
428 
429   EXPECT_THAT(new_while_op->while_condition()->root_instruction(),
430               op::Eq(op::Constant(),
431                      op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
432 }
433 
// The loop body's root is its parameter rather than a tuple instruction; the
// pass is expected to leave such a loop untouched.
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
  const string hlo_text = R"(
  HloModule BodyHasNonTupleRoot
  BodyHasNonTupleRoot.passthrough {
    ROOT param = (s32[], s32[]) parameter(0)
  }
  BodyHasNonTupleRoot.always_true {
    param.1 = (s32[], s32[]) parameter(0)
    ROOT constant = pred[] constant(true)
  }
  ENTRY BodyHasNonTupleRoot {
    init_value = (s32[], s32[]) parameter(0)
    ROOT while = (s32[], s32[]) while((s32[], s32[]) init_value),
      condition=BodyHasNonTupleRoot.always_true,
      body=BodyHasNonTupleRoot.passthrough
  }
  )";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
455 
// The loop body's root is a tuple-shaped custom-call rather than a tuple
// instruction; the pass is expected to leave such a loop untouched.
TEST_F(WhileLoopSimplifierTest,
       LoopWithNonTupleBodyRootInstructionNotSimplified) {
  const string hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
      custom_call_target="x"
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(44)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
  }
  )";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  WhileLoopSimplifier simplifier;
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
488 
// A loop that carries an array constant (the same constant.4 feeds tuple
// elements 1 and 2) and runs more than once; the pass is expected to report no
// change.
TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
  // The declared shapes of `ROOT tuple` and `tuple.1` list all three elements,
  // matching their three operands and the three-element shape used by the
  // loop parameter and the while instruction.
  const string hlo_string = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    get-tuple-element.3 = s32[3]{0} get-tuple-element(loop_var.1), index=2
    add.2 = s32[3]{0} add(get-tuple-element.2, get-tuple-element.3)
    ROOT tuple = (s32[], s32[3]{0}, s32[3]{0}) tuple(add, add.2, get-tuple-element.3)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(47)
    ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}, s32[3]{0}) tuple(constant.3, constant.4, constant.4)
    ROOT while = (s32[], s32[3]{0}, s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
  }
  )";

  auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
}
520 
// A loop carrying a nested tuple is rewritten to carry a flat tuple, while the
// entry computation's root keeps the original nested shape.
TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) {
  const string hlo_text = R"(
  HloModule Test
  Body {
    param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
    ta = (s32[1]) get-tuple-element(param), index=0
    a = s32[1] get-tuple-element(ta), index=0
    a.1 = s32[1] add(a, a)
    tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
    ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
  }
  Cond {
    param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
    ROOT cond = pred[] constant(true)
  }
  ENTRY Loop {
    a = s32[1] constant({0})
    b = s32[2] constant({0,1})
    c = s32[3] constant({0,1,2})
    d = s32[4] constant({0,1,2,3})
    ta = (s32[1]) tuple(a)
    td = (s32[4]) tuple(d)
    tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
    init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
    ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
      condition=Cond, body=Body
  })";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // Remove the old, now-dead loop with DCE so that exactly one while loop is
  // left and FindFirstWhile picks the new one.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());

  HloInstruction* flat_while = FindFirstWhile(module.get());
  const Shape flat_shape =
      ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie();
  SCOPED_TRACE(module->ToString());

  // The loop, its body root, and both computation parameters are all flat.
  HloComputation* body = flat_while->while_body();
  HloComputation* cond = flat_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(flat_while->shape(), flat_shape));
  EXPECT_TRUE(ShapeUtil::Equal(body->root_instruction()->shape(), flat_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), flat_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), flat_shape));

  // The entry root keeps the original nested shape.
  const Shape nested_shape =
      ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Equal(
      module->entry_computation()->root_instruction()->shape(), nested_shape));
}
571 
572 // Edge-case: All elements of the loop carry are constants which can be removed,
573 // leaving us with a nullary loop.  This is a special case, we just replace the
574 // loop with its init.
TEST_F(WhileLoopSimplifierTest,OnlyConstantsInLoopCarry)575 TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) {
576   const string hlo_string = R"(
577   HloModule Test
578   Body {
579     param = (s32[1]) parameter(0)
580     a = s32[1] constant({0})
581     ROOT tuple = (s32[1]) tuple(a)
582   }
583   Cond {
584     param = (s32[1]) parameter(0)
585     ROOT cond = pred[] constant(true)
586   }
587   ENTRY Loop {
588     a = s32[1] constant({0})
589     init = (s32[1]) tuple(a)
590     ROOT while = (s32[1]) while(init), condition=Cond, body=Body
591   })";
592 
593   auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
594   EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
595   EXPECT_TRUE(HloDCE().Run(m.get()).ok());
596   EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok());
597   EXPECT_THAT(m->entry_computation()->root_instruction(),
598               op::Tuple(op::Constant()));
599 }
600 
// Of the three carried elements only `b` holds the same constant in both the
// init tuple and the loop body, so only `b` is dropped from the loop carry.
TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) {
  const string hlo_text = R"(
  HloModule Test
  Body {
    param = (s32[1], s32[2], s32[3]) parameter(0)
    a = s32[1] get-tuple-element(param), index=0
    a.1 = s32[1] add(a, a)
    b = s32[2] constant({1,1})
    c = s32[3] constant({10,10,10})
    ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c)
  }
  Cond {
    param = (s32[1], s32[2], s32[3]) parameter(0)
    /* Use each tuple element.  The verifier will then ensure that if any of
     * these get modified, they're replaced with values of the correct shape. */
    a = s32[1] get-tuple-element(param), index=0
    b = s32[2] get-tuple-element(param), index=1
    c = s32[3] get-tuple-element(param), index=2
    ROOT cond = pred[] constant(true)
  }
  ENTRY Loop {
    /* Only `b` should be simplified away.  `a` is not a constant within the
     * loop, and `c`'s value changes depending on whether we run 0 or 1
     * iterations of the loop. */
    a = s32[1] constant({0})
    b = s32[2] constant({1,1})
    c = s32[3] constant({2,2,2})
    init = (s32[1], s32[2], s32[3]) tuple(a,b,c)
    ROOT while = (s32[1], s32[2], s32[3]) while(init),
      condition=Cond, body=Body
  })";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // Remove the old, now-dead loop with DCE so that exactly one while loop is
  // left and FindFirstWhile picks the new one.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());
  // The tuple simplifier makes the resulting HLO a little easier to check.
  EXPECT_TRUE(TupleSimplifier().Run(module.get()).ok());

  HloInstruction* slim_while = FindFirstWhile(module.get());
  const Shape slim_shape = ParseShape("(s32[1], s32[3])").ValueOrDie();

  // The loop, its body root, and both computation parameters lose element `b`.
  HloComputation* body = slim_while->while_body();
  HloComputation* cond = slim_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(slim_while->shape(), slim_shape));
  EXPECT_TRUE(ShapeUtil::Equal(body->root_instruction()->shape(), slim_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), slim_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), slim_shape));

  // The entry root keeps its three-element shape, with a constant in the
  // middle position.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(ShapeUtil::Equal(
      root->shape(), ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie()));
  EXPECT_THAT(root, op::Tuple(_, op::Constant(), _));
}
658 
// HLO template shared by the MergeInductionVariables tests.  Every occurrence
// of the placeholder TYPE must be replaced with a concrete element type (the
// tests substitute "s32" and "s16") before the text is parsed.  The loop body
// adds 1 to element 0 and -1 to element 1 of the loop tuple on each iteration.
const char* const kSimpleMergeInductionVariablesModule = R"(
  HloModule Test
  Body {
    param = (TYPE[], TYPE[], TYPE[]) parameter(0)

    a = TYPE[] get-tuple-element(param), index=0
    one = TYPE[] constant(1)
    a1 = TYPE[] add(a, one)

    b = TYPE[] get-tuple-element(param), index=1
    negone = TYPE[] constant(-1)
    b1 = TYPE[] add(b, negone)

    c = TYPE[] add(a, b)

    ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c)
  }
  Cond {
    param = (TYPE[], TYPE[], TYPE[]) parameter(0)
    a = TYPE[] get-tuple-element(param), index=0
    b = TYPE[] get-tuple-element(param), index=1
    sum = TYPE[] power(a, b)
    ten = TYPE[] constant(10)
    ROOT cond = pred[] compare(sum, ten), direction=LT
  }
  ENTRY Loop {
    a = TYPE[] constant(10)
    b = TYPE[] constant(100)
    c = TYPE[] constant(0)
    init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c)
    while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body

    a1 = TYPE[] get-tuple-element(while), index=0
    b1 = TYPE[] get-tuple-element(while), index=1
    ROOT sum = TYPE[] add(a1, b1)
  })";
695 
// With s32 induction variables the pass appends a shared loop counter to the
// loop tuple and rewrites the condition in terms of additions.
TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) {
  const string hlo_text = absl::StrReplaceAll(
      kSimpleMergeInductionVariablesModule, {{"TYPE", "s32"}});

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(module.get()).ValueOrDie());
  // DCE removes the old loop so that only one while loop remains and is easy
  // to locate; the tuple simplifier makes the resulting HLO easier to check.
  EXPECT_TRUE(HloDCE().Run(module.get()).ok());
  EXPECT_TRUE(TupleSimplifier().Run(module.get()).ok());

  HloInstruction* merged_while = FindFirstWhile(module.get());
  // A new s32[] loop counter is expected at the end of the tuple.
  SCOPED_TRACE(module->ToString());
  const Shape merged_shape =
      ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie();

  HloComputation* body = merged_while->while_body();
  HloComputation* cond = merged_while->while_condition();
  EXPECT_TRUE(ShapeUtil::Equal(merged_while->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->root_instruction()->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(body->parameter_instruction(0)->shape(), merged_shape));
  EXPECT_TRUE(
      ShapeUtil::Equal(cond->parameter_instruction(0)->shape(), merged_shape));

  // Elements 0 and 1 are passed through unchanged, and only the new counter
  // (element 3) is incremented.
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(op::Parameter(), 0),
                        op::GetTupleElement(op::Parameter(), 1), op::Add(),
                        op::Add(op::GetTupleElement(op::Parameter(), 3),
                                op::Constant())));
  EXPECT_THAT(cond->root_instruction(),
              op::Lt(op::Power(op::Add(), op::Add()), op::Constant()));
}
731 
732 // We shouldn't merge S16 induction variables; we can't create constants of this
733 // type because S16 literals are not implemented.
TEST_F(WhileLoopSimplifierTest,MergeInductionVariables_SkipS16)734 TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) {
735   string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule,
736                                           {{"TYPE", "s16"}});
737   EXPECT_FALSE(
738       WhileLoopSimplifier()
739           .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get())
740           .ValueOrDie());
741 }
742 
743 }  // namespace
744 }  // namespace xla
745