1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 
24 namespace xla {
25 namespace {
26 
27 namespace op = xla::testing::opcode_matchers;
28 
// Fixture shared by every test in this file.  It extends HloTestBase with one
// helper for building a placeholder while-loop condition.
class WhileLoopInvariantCodeMotionTest : public HloTestBase {
 public:
  // Makes a computation which has one parameter, of the given shape, and always
  // returns PRED[]{true}.  This is useful as a dummy loop condition.
  //
  // `param_shape` is the shape given to the computation's single parameter.
  // `module` is the module the new computation is added to as an embedded
  // computation; the returned pointer is whatever that call yields.
  HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
                                            HloModule* module);
};
36 
FindOnlyWhileInstruction(HloComputation * computation,HloInstruction ** while_instruction)37 static void FindOnlyWhileInstruction(HloComputation* computation,
38                                      HloInstruction** while_instruction) {
39   *while_instruction = nullptr;
40   for (auto* instr : computation->instructions()) {
41     if (instr->opcode() == HloOpcode::kWhile) {
42       ASSERT_EQ(*while_instruction, nullptr);
43       *while_instruction = instr;
44     }
45   }
46 
47   ASSERT_NE(*while_instruction, nullptr);
48 }
49 
MakeAlwaysTrueComputation(const Shape & param_shape,HloModule * module)50 HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
51     const Shape& param_shape, HloModule* module) {
52   HloComputation::Builder builder(TestName() + ".always_true");
53   builder.AddInstruction(
54       HloInstruction::CreateParameter(0, param_shape, "param"));
55   builder.AddInstruction(
56       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
57   return module->AddEmbeddedComputation(builder.Build());
58 }
59 
// Checks that a single loop-invariant add is moved out of the while body.
//
// The body maps (a, b, _) to (a, b, a + b).  Tuple elements 0 and 1 are passed
// through unchanged, so the add depends only on values that do not change
// between iterations.  After the pass the add is expected in the entry
// computation and nowhere in the loop body.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));
  HloComputation* entry_computation = m->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  // A fatal failure inside the helper only returns from the helper, so wrap the
  // call: otherwise a missing while would leave `transformed_while` null and
  // the dereference below would crash instead of reporting a test failure.
  HloInstruction* transformed_while = nullptr;
  ASSERT_NO_FATAL_FAILURE(
      FindOnlyWhileInstruction(entry_computation, &transformed_while));

  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
}
101 
TEST_F(WhileLoopInvariantCodeMotionTest,HoistInvariantOperationTree)102 TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
103   auto m = CreateNewVerifiedModule();
104   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
105   Shape while_shape =
106       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
107 
108   HloComputation* while_body = [&]() {
109     HloComputation::Builder builder(TestName() + ".while_body");
110     HloInstruction* param = builder.AddInstruction(
111         HloInstruction::CreateParameter(0, while_shape, "param"));
112     HloInstruction* gte_0 = builder.AddInstruction(
113         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
114     HloInstruction* gte_1 = builder.AddInstruction(
115         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
116     HloInstruction* gte_2_loop_variant = builder.AddInstruction(
117         HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));
118 
119     HloInstruction* add_result =
120         builder.AddInstruction(HloInstruction::CreateBinary(
121             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
122     HloInstruction* mul_result =
123         builder.AddInstruction(HloInstruction::CreateBinary(
124             scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
125     HloInstruction* negate_result =
126         builder.AddInstruction(HloInstruction::CreateUnary(
127             scalar_s32, HloOpcode::kNegate, mul_result));
128     HloInstruction* constant = builder.AddInstruction(
129         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(4)));
130     HloInstruction* sub_result =
131         builder.AddInstruction(HloInstruction::CreateBinary(
132             scalar_s32, HloOpcode::kSubtract, negate_result, constant));
133     HloInstruction* divide_result =
134         builder.AddInstruction(HloInstruction::CreateBinary(
135             scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
136     builder.AddInstruction(
137         HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));
138 
139     return m->AddEmbeddedComputation(builder.Build());
140   }();
141 
142   HloComputation::Builder builder(TestName());
143   auto* init_value = builder.AddInstruction(
144       HloInstruction::CreateParameter(0, while_shape, "init_value"));
145   builder.AddInstruction(HloInstruction::CreateWhile(
146       while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
147       init_value));
148   HloComputation* entry_computation = m->AddEntryComputation(builder.Build());
149   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
150                           WhileLoopInvariantCodeMotion{}.Run(m.get()));
151   EXPECT_TRUE(simplified_loop);
152 
153   HloInstruction* transformed_while;
154   FindOnlyWhileInstruction(entry_computation, &transformed_while);
155 
156   EXPECT_THAT(entry_computation->instructions(),
157               AllOf(Contains(op::Add()), Contains(op::Multiply()),
158                     Contains(op::Negate()), Contains(op::Subtract()),
159                     Contains(op::Constant()),
160 
161                     // The division had a loop varying operand so that better
162                     // not be hoisted.
163                     Not(Contains(op::Divide()))));
164 
165   EXPECT_THAT(transformed_while->while_body()->instructions(),
166               Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
167                              op::Subtract(), op::Constant()))));
168 
169   EXPECT_THAT(transformed_while->while_body()->instructions(),
170               Contains(op::Divide()));
171 }
172 
// Basic negative test: the add expression is not loop invariant.
//
// The body maps (a, b) to (a, a + b), so operand b of the add is replaced on
// every iteration.  The pass is expected to report no change and the add must
// still be in the body afterwards.
TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistTriviallyLoopVaryingComputation) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  m->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  // Fatal on purpose: `while_inst` was captured before the pass ran.  If the
  // pass did rewrite the loop, that pointer may no longer refer to a live
  // instruction, so stop here rather than inspect it below.
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
211 
// Negative test where the operands of the add swap places every iteration.
//
// The body maps (a, b, _) to (b, a, a + b): elements 0 and 1 are not passed
// through at their own index, so neither operand of the add is loop invariant.
// The pass is expected to report no change and to leave the add in the body.
TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistLoopVaryingComputationWithAlternatingTuples) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_1, gte_0, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  m->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  // Fatal on purpose: `while_inst` was captured before the pass ran.  If the
  // pass did rewrite the loop, that pointer may no longer refer to a live
  // instruction, so stop here rather than inspect it below.
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
250 
// An outfeed must stay inside the loop even though the value it sends out
// (tuple element 0) is passed through the body unchanged.
//
// Body: (a, b, tok) -> (a, b, outfeed(a, tok)).
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
  auto m = CreateNewVerifiedModule();
  auto s32_scalar = ShapeUtil::MakeShape(S32, {});
  auto token_type = ShapeUtil::MakeTokenShape();
  Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({s32_scalar, s32_scalar, token_type});

  HloComputation* loop_body;
  {
    HloComputation::Builder body_builder(TestName() + ".while_body");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "param"));
    HloInstruction* first = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_scalar, state, 0));
    HloInstruction* second = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_scalar, state, 1));
    HloInstruction* token_in = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(token_type, state, 2));
    HloInstruction* token_out = body_builder.AddInstruction(
        HloInstruction::CreateOutfeed(s32_scalar, first, token_in, ""));
    body_builder.AddInstruction(
        HloInstruction::CreateTuple({first, second, token_out}));

    loop_body = m->AddEmbeddedComputation(body_builder.Build());
  }
  HloComputation* loop_cond =
      MakeAlwaysTrueComputation(loop_state_shape, m.get());

  // Entry: feed the same scalar parameter into both data slots, plus a fresh
  // token, and return element 0 of the loop result.
  HloComputation::Builder entry_builder(TestName());
  auto* scalar_input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, s32_scalar, "param"));
  auto* fresh_token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  auto* loop_init = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({scalar_input, scalar_input, fresh_token}));
  auto* loop = entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, loop_cond, loop_body, loop_init));
  entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32_scalar, loop, 0));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  ASSERT_FALSE(changed);

  EXPECT_THAT(loop->while_body()->instructions(), Contains(op::Outfeed()));
}
296 
TEST_F(WhileLoopInvariantCodeMotionTest,DontHoistBitcastAlone)297 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
298   // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
299   // bitcast either.
300   auto m = CreateNewVerifiedModule();
301   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
302   auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1});
303   auto token_shape = ShapeUtil::MakeTokenShape();
304   Shape while_shape =
305       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
306 
307   HloComputation* while_body = [&]() {
308     HloComputation::Builder builder(TestName() + ".while_body");
309     HloInstruction* param = builder.AddInstruction(
310         HloInstruction::CreateParameter(0, while_shape, "param"));
311     HloInstruction* gte_0 = builder.AddInstruction(
312         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
313     HloInstruction* gte_1 = builder.AddInstruction(
314         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
315     HloInstruction* in_token = builder.AddInstruction(
316         HloInstruction::CreateGetTupleElement(token_shape, param, 2));
317     HloInstruction* bitcast_inst =
318         builder.AddInstruction(HloInstruction::CreateUnary(
319             effective_scalar_s32, HloOpcode::kBitcast, gte_0));
320     HloInstruction* out_token =
321         builder.AddInstruction(HloInstruction::CreateOutfeed(
322             effective_scalar_s32, bitcast_inst, in_token, ""));
323     builder.AddInstruction(
324         HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
325 
326     return m->AddEmbeddedComputation(builder.Build());
327   }();
328 
329   HloComputation::Builder builder(TestName());
330   auto* scalar_param = builder.AddInstruction(
331       HloInstruction::CreateParameter(0, scalar_s32, "param"));
332   auto* token = builder.AddInstruction(HloInstruction::CreateToken());
333   auto* init_value = builder.AddInstruction(
334       HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
335   auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
336       while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
337       init_value));
338   builder.AddInstruction(
339       HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
340 
341   m->AddEntryComputation(builder.Build());
342 
343   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
344                           WhileLoopInvariantCodeMotion{}.Run(m.get()));
345   EXPECT_FALSE(simplified_loop);
346 
347   EXPECT_THAT(while_inst->while_body()->instructions(),
348               Contains(op::Outfeed()));
349   EXPECT_THAT(while_inst->while_body()->instructions(),
350               Contains(op::Bitcast()));
351 }
352 
// The bitcast's user can be hoisted, so hoist the bitcast too.
//
// Body: (a, b, _) -> (a, b, bitcast(a) + b), where a is S32[] and b is S32[1].
// Both a and b are passed through unchanged, so the add and the bitcast feeding
// it are expected to end up in the entry computation.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1});
  Shape while_shape = ShapeUtil::MakeTupleShape(
      {scalar_s32, effective_scalar_s32, effective_scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1));
    HloInstruction* bitcast_inst =
        builder.AddInstruction(HloInstruction::CreateUnary(
            effective_scalar_s32, HloOpcode::kBitcast, gte_0));
    HloInstruction* add_inst =
        builder.AddInstruction(HloInstruction::CreateBinary(
            effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  HloComputation* entry_computation = m->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  // A fatal failure inside the helper only returns from the helper, so wrap the
  // call: otherwise a missing while would leave `transformed_while` null and
  // the dereferences below would crash instead of reporting a test failure.
  HloInstruction* transformed_while = nullptr;
  ASSERT_NO_FATAL_FAILURE(
      FindOnlyWhileInstruction(entry_computation, &transformed_while));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Bitcast())));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
}
404 
// Same body as the single-add hoisting case, (a, b, _) -> (a, b, a + b), except
// that the add carries a control dependency from the body parameter.  The pass
// is expected to report that it changed nothing.
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
  auto m = CreateNewVerifiedModule();
  auto s32_scalar = ShapeUtil::MakeShape(S32, {});
  Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({s32_scalar, s32_scalar, s32_scalar});

  // Built in a plain scope rather than a value-returning lambda because
  // TF_ASSERT_OK is used while constructing the body.
  HloComputation* loop_body;
  {
    HloComputation::Builder body_builder(TestName() + ".while_body");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "param"));
    HloInstruction* lhs = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_scalar, state, 0));
    HloInstruction* rhs = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_scalar, state, 1));
    HloInstruction* sum = body_builder.AddInstruction(
        HloInstruction::CreateBinary(s32_scalar, HloOpcode::kAdd, lhs, rhs));
    TF_ASSERT_OK(state->AddControlDependencyTo(sum));
    body_builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs, sum}));

    loop_body = m->AddEmbeddedComputation(body_builder.Build());
  }
  HloComputation* loop_cond =
      MakeAlwaysTrueComputation(loop_state_shape, m.get());

  HloComputation::Builder entry_builder(TestName());
  auto* loop_init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "init_value"));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, loop_cond, loop_body, loop_init));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(changed);
}
441 
// Runs the pass on a loop whose body is built from just its parameter, with a
// get-tuple-element appended after the computation has been built.  The pass
// is expected to report that it changed nothing.
TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
  auto m = CreateNewVerifiedModule();
  auto s32_scalar = ShapeUtil::MakeShape(S32, {});
  Shape loop_state_shape = ShapeUtil::MakeTupleShape({s32_scalar, s32_scalar});

  HloComputation* passthrough_body;
  {
    HloComputation::Builder body_builder(TestName() + ".passthrough");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "param"));
    passthrough_body = m->AddEmbeddedComputation(body_builder.Build());

    // Added only after Build(); presumably this leaves the parameter, not a
    // tuple instruction, as the body's root -- see the test name.
    passthrough_body->AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_scalar, state, 1));
  }
  HloComputation* loop_cond =
      MakeAlwaysTrueComputation(loop_state_shape, m.get());

  HloComputation::Builder entry_builder(TestName());
  auto* loop_init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "init_value"));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, loop_cond, passthrough_body, loop_init));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(changed);
}
469 
// HLO text for a module with one while loop whose condition is always true and
// whose body adds the constant {3, 4} to the single f32[2] loop-carried value.
// The constant is the only instruction in the body that does not depend on the
// body parameter.  Shared by the tests that run the pass with and without
// hoist_constants.
const char* const kConstantHoistingTestCase = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2]{0}) parameter(0)
  p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
  const = f32[2]{0} constant({3, 4})
  add.0 = f32[2]{0} add(p_body.1, const)
  ROOT root = (f32[2]{0}) tuple(add.0)
}

condition {
  p_cond = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2]{0} constant({1, 2})
  while_init = (f32[2]{0}) tuple(const_0)
  ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
492 
// With hoist_constants enabled, the constant in the loop body is expected to be
// moved out and fed back in through an extra tuple element of a widened body
// named "wide.body".
TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) {
  // Report a parse/verify error as a test failure instead of aborting the
  // whole test binary, which is what ValueOrDie() would do.
  TF_ASSERT_OK_AND_ASSIGN(
      auto m, ParseAndReturnVerifiedModule(kConstantHoistingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  HloComputation* while_body = m->GetComputationWithName("wide.body");
  ASSERT_NE(while_body, nullptr);

  // We expect the while body to be the equivalent of:
  //
  //  wide.body {
  //    wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0)
  //    get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0
  //    tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1)
  //    get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0
  //    get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1
  //    add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7)
  //    tuple.3 = (f32[2]{0}) tuple(add.1)
  //    get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0
  //    get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1
  //    ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8,
  //                                                get-tuple-element.9)
  //  }

  // Each matcher below mirrors the instruction of the same name in the listing
  // above; only the root is matched, which checks the whole operand tree.
  auto wide_param_1 = op::Parameter(0);
  auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0);
  auto tuple_1 = op::Tuple(get_tuple_element_1);
  auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0);
  auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1);
  auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7);
  auto tuple_3 = op::Tuple(add_1);
  auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0);
  auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1);
  auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9);

  EXPECT_THAT(while_body->root_instruction(), tuple_4);
}
533 
// With a default-constructed pass, the constant in the loop body is expected to
// stay where it is, so the pass reports no change.
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) {
  // Report a parse/verify error as a test failure instead of aborting the
  // whole test binary, which is what ValueOrDie() would do.
  TF_ASSERT_OK_AND_ASSIGN(
      auto m, ParseAndReturnVerifiedModule(kConstantHoistingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
541 
// The loop below continues only while tuple element 3 equals 42, and the body
// always writes -1 into that element, so the condition cannot hold again once
// the body has run.  The add of the two pass-through elements is therefore
// expected to stay in the body and the pass to report no change.
TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) {
  const char* const kSingleIterationLoop = R"(
    HloModule ModuleWithWhile

    body {
      p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0)
      val.0 = f32[2] get-tuple-element(p_body), index=0
      val.1 = f32[2] get-tuple-element(p_body), index=1
      add = f32[2] add(val.0, val.1)
      const = s32[] constant(-1)
      ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const)
    }

    condition {
      p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
      gte = s32[] get-tuple-element(p_cond), index=3
      const = s32[] constant(42)
      ROOT result = pred[] compare(gte, const), direction=EQ
    }

    ENTRY entry {
      param.0 = f32[2] parameter(0)
      param.1 = s32[] parameter(1)
      while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1)
      ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body
    })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(kSingleIterationLoop));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopInvariantCodeMotion{}.Run(parsed.get()));
  EXPECT_FALSE(changed);
}
575 
// HLO text for a module with one while loop whose condition is always true and
// whose body never reads its parameter: it builds an f32[1024, 1024] iota, adds
// it to itself, and reduces the sum to a scalar with `mul` and an initial value
// of 1.0.  Shared by the tests that run the pass with and without
// hoist_size_inflating_ops.
const char* const kInflatingTestCase = R"(
HloModule ModuleWithWhile

mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}

body {
  p_body = (f32[]) parameter(0)
  iota = f32[1024, 1024] iota(), iota_dimension=0
  add = f32[1024, 1024] add(iota, iota)
  constant = f32[] constant(1.0)
  reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul
  ROOT root = (f32[]) tuple(reduce)
}

condition {
  p_cond = (f32[]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  param = f32[] parameter(0)
  while_init = (f32[]) tuple(param)
  ROOT while = (f32[]) while(while_init), condition=condition, body=body
}
)";
605 
// With hoist_size_inflating_ops left at its default (only hoist_constants is
// passed), the pass is expected to change the module and leave no iota in the
// widened loop body "wide.body".
TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) {
  // Report a parse/verify error as a test failure instead of aborting the
  // whole test binary, which is what ValueOrDie() would do.
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(kInflatingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  HloComputation* while_body = m->GetComputationWithName("wide.body");
  ASSERT_NE(while_body, nullptr);
  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota())));
}
618 
// With hoist_size_inflating_ops explicitly disabled, the pass is expected to
// leave the module unchanged.
TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) {
  // Report a parse/verify error as a test failure instead of aborting the
  // whole test binary, which is what ValueOrDie() would do.
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(kInflatingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion(/*hoist_constants=*/true,
                                   /*hoist_size_inflating_ops=*/false)
          .Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
629 
630 }  // namespace
631 }  // namespace xla
632