1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23
24 namespace xla {
25 namespace {
26
27 namespace op = xla::testing::opcode_matchers;
28
29 class WhileLoopInvariantCodeMotionTest : public HloTestBase {
30 public:
31 // Makes a computation which has one parameter, of the given shape, and always
32 // returns PRED[]{true}. This is useful as a dummy loop condition.
33 HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
34 HloModule* module);
35 };
36
FindOnlyWhileInstruction(HloComputation * computation,HloInstruction ** while_instruction)37 static void FindOnlyWhileInstruction(HloComputation* computation,
38 HloInstruction** while_instruction) {
39 *while_instruction = nullptr;
40 for (auto* instr : computation->instructions()) {
41 if (instr->opcode() == HloOpcode::kWhile) {
42 ASSERT_EQ(*while_instruction, nullptr);
43 *while_instruction = instr;
44 }
45 }
46
47 ASSERT_NE(*while_instruction, nullptr);
48 }
49
MakeAlwaysTrueComputation(const Shape & param_shape,HloModule * module)50 HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
51 const Shape& param_shape, HloModule* module) {
52 HloComputation::Builder builder(TestName() + ".always_true");
53 builder.AddInstruction(
54 HloInstruction::CreateParameter(0, param_shape, "param"));
55 builder.AddInstruction(
56 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
57 return module->AddEmbeddedComputation(builder.Build());
58 }
59
// Checks that an add whose operands are both passed through the loop unchanged
// is moved out of the while body and into the entry computation.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  // Body: (a, b, _) -> (a, b, a + b). Elements 0 and 1 are forwarded
  // unchanged, so the add depends only on values that never vary.
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));
  HloComputation* entry_computation = m->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  // FindOnlyWhileInstruction reports problems through fatal assertions, which
  // only return from the helper itself. Propagate them so a failed lookup
  // stops this test instead of dereferencing a null pointer below.
  HloInstruction* transformed_while = nullptr;
  ASSERT_NO_FATAL_FAILURE(
      FindOnlyWhileInstruction(entry_computation, &transformed_while));

  // The add must now live in the entry computation and nowhere in the body.
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
}
101
// Checks that a whole tree of operations built from unchanging tuple elements
// is hoisted, while the one operation that consumes a loop-varying value stays
// in the while body.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  // Body: (a, b, c) -> (a, b, (-((a + b) * b) - 4) / c). Element 2 is
  // overwritten every iteration, so only the divide depends on a value that
  // changes between iterations.
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* gte_2_loop_variant = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));

    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    HloInstruction* mul_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
    HloInstruction* negate_result =
        builder.AddInstruction(HloInstruction::CreateUnary(
            scalar_s32, HloOpcode::kNegate, mul_result));
    HloInstruction* constant = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(4)));
    HloInstruction* sub_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kSubtract, negate_result, constant));
    HloInstruction* divide_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));
  HloComputation* entry_computation = m->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  // FindOnlyWhileInstruction reports problems through fatal assertions, which
  // only return from the helper itself. Propagate them so a failed lookup
  // stops this test instead of dereferencing a null pointer below.
  HloInstruction* transformed_while = nullptr;
  ASSERT_NO_FATAL_FAILURE(
      FindOnlyWhileInstruction(entry_computation, &transformed_while));

  EXPECT_THAT(entry_computation->instructions(),
              AllOf(Contains(op::Add()), Contains(op::Multiply()),
                    Contains(op::Negate()), Contains(op::Subtract()),
                    Contains(op::Constant()),

                    // The division had a loop varying operand so that better
                    // not be hoisted.
                    Not(Contains(op::Divide()))));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
                             op::Subtract(), op::Constant()))));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Contains(op::Divide()));
}
172
TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistTriviallyLoopVaryingComputation) {
  // Basic negative test: the add expression is not loop invariant.
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  // Body: (a, b) -> (a, a + b). Element 1 is rewritten on every iteration, so
  // the add reads a value that changes between iterations.
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  m->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  // Fatal on purpose: the check below goes through `while_inst`, which is only
  // known to still be the loop in the module when the pass changed nothing
  // (the hoisting tests in this file look the while instruction up again after
  // a successful run).
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
211
// Negative test: elements 0 and 1 swap places on every iteration, so although
// the body never computes new values for them, neither element is the same
// value from one iteration to the next and the add must stay in the body.
TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistLoopVaryingComputationWithAlternatingTuples) {
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  // Body: (a, b, _) -> (b, a, a + b).
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_1, gte_0, add_result}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  m->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  // Fatal on purpose: the check below goes through `while_inst`, which is only
  // known to still be the loop in the module when the pass changed nothing
  // (the hoisting tests in this file look the while instruction up again after
  // a successful run).
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
250
// Negative test: the body outfeeds an element that is forwarded unchanged. The
// pass is expected to report no change and to leave the outfeed in the body.
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
  auto m = CreateNewVerifiedModule();
  const auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  const auto token_shape = ShapeUtil::MakeTokenShape();
  const Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});

  // Body: (a, b, token) -> (a, b, outfeed(a, token)).
  HloComputation* while_body = [&]() {
    HloComputation::Builder body_builder(TestName() + ".while_body");
    HloInstruction* body_param = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* elem_0 = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, body_param, 0));
    HloInstruction* elem_1 = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, body_param, 1));
    HloInstruction* token_in = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(token_shape, body_param, 2));
    HloInstruction* token_out = body_builder.AddInstruction(
        HloInstruction::CreateOutfeed(scalar_s32, elem_0, token_in, ""));
    body_builder.AddInstruction(
        HloInstruction::CreateTuple({elem_0, elem_1, token_out}));

    return m->AddEmbeddedComputation(body_builder.Build());
  }();

  // Entry: the loop state is built from one scalar parameter (used twice) and
  // a fresh token; element 0 of the loop result is added last.
  HloComputation::Builder entry_builder(TestName());
  auto* scalar_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param"));
  auto* token = entry_builder.AddInstruction(HloInstruction::CreateToken());
  auto* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
  HloComputation* condition = MakeAlwaysTrueComputation(while_shape, m.get());
  auto* while_inst = entry_builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, condition, while_body, init_value));
  entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Outfeed()));
}
296
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
  // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
  // bitcast either.
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1});
  auto token_shape = ShapeUtil::MakeTokenShape();
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});

  // Body: (a, b, token) -> (a, b, outfeed(bitcast(a), token)).
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* in_token = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(token_shape, param, 2));
    HloInstruction* bitcast_inst =
        builder.AddInstruction(HloInstruction::CreateUnary(
            effective_scalar_s32, HloOpcode::kBitcast, gte_0));
    HloInstruction* out_token =
        builder.AddInstruction(HloInstruction::CreateOutfeed(
            effective_scalar_s32, bitcast_inst, in_token, ""));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, out_token}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* scalar_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param"));
  auto* token = builder.AddInstruction(HloInstruction::CreateToken());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));
  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));

  m->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  // Fatal on purpose: the checks below go through `while_inst`, which is only
  // known to still be the loop in the module when the pass changed nothing
  // (the hoisting tests in this file look the while instruction up again after
  // a successful run).
  ASSERT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Outfeed()));
  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Bitcast()));
}
352
TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
  // The bitcast's user can be hoisted, so hoist the bitcast too.
  auto m = CreateNewVerifiedModule();
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1});
  Shape while_shape = ShapeUtil::MakeTupleShape(
      {scalar_s32, effective_scalar_s32, effective_scalar_s32});

  // Body: (a, b, _) -> (a, b, bitcast(a) + b). Elements 0 and 1 are forwarded
  // unchanged, so both the bitcast and the add read only unchanging values.
  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1));
    HloInstruction* bitcast_inst =
        builder.AddInstruction(HloInstruction::CreateUnary(
            effective_scalar_s32, HloOpcode::kBitcast, gte_0));
    HloInstruction* add_inst =
        builder.AddInstruction(HloInstruction::CreateBinary(
            effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));

    return m->AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body,
      init_value));

  HloComputation* entry_computation = m->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  // FindOnlyWhileInstruction reports problems through fatal assertions, which
  // only return from the helper itself. Propagate them so a failed lookup
  // stops this test instead of dereferencing a null pointer below.
  HloInstruction* transformed_while = nullptr;
  ASSERT_NO_FATAL_FAILURE(
      FindOnlyWhileInstruction(entry_computation, &transformed_while));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Bitcast())));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
}
404
// Negative test: the add reads only forwarded tuple elements, but a control
// dependency from the body parameter to the add is attached. The pass is
// expected to report no change.
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
  auto m = CreateNewVerifiedModule();
  const auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  const Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  // Built in a plain scope rather than a lambda: TF_ASSERT_OK needs to be able
  // to return from the test function itself.
  HloComputation* while_body;
  {
    HloComputation::Builder body_builder(TestName() + ".while_body");
    HloInstruction* body_param = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* elem_0 = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, body_param, 0));
    HloInstruction* elem_1 = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, body_param, 1));
    HloInstruction* sum = body_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_s32, HloOpcode::kAdd, elem_0,
                                     elem_1));
    TF_ASSERT_OK(body_param->AddControlDependencyTo(sum));
    body_builder.AddInstruction(
        HloInstruction::CreateTuple({elem_0, elem_1, sum}));

    while_body = m->AddEmbeddedComputation(body_builder.Build());
  }

  HloComputation::Builder entry_builder(TestName());
  auto* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  HloComputation* condition = MakeAlwaysTrueComputation(while_shape, m.get());
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, condition, while_body, init_value));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
441
// Runs the pass on a loop whose body is not assembled around a tuple
// instruction and expects it to report no change.
TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
  auto m = CreateNewVerifiedModule();
  const auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  const Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    // The computation is built from the parameter alone; the
    // get-tuple-element is attached to the finished computation afterwards.
    HloComputation::Builder body_builder(TestName() + ".passthrough");
    HloInstruction* body_param = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloComputation* body = m->AddEmbeddedComputation(body_builder.Build());

    body->AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, body_param, 1));
    return body;
  }();

  HloComputation::Builder entry_builder(TestName());
  auto* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  HloComputation* condition = MakeAlwaysTrueComputation(while_shape, m.get());
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, condition, while_body, init_value));
  m->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
469
// HLO text shared by the constant-hoisting tests. The while body adds the
// constant {3, 4} to the only element of the loop state, and the condition is
// the constant `true`.
const char* const kConstantHoistingTestCase = R"(
HloModule ModuleWithWhile

body {
p_body = (f32[2]{0}) parameter(0)
p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
const = f32[2]{0} constant({3, 4})
add.0 = f32[2]{0} add(p_body.1, const)
ROOT root = (f32[2]{0}) tuple(add.0)
}

condition {
p_cond = (f32[2]{0}) parameter(0)
ROOT result = pred[] constant(true)
}

ENTRY entry {
const_0 = f32[2]{0} constant({1, 2})
while_init = (f32[2]{0}) tuple(const_0)
ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
492
// With hoist_constants enabled, the constant in the loop body is expected to
// be replaced by an extra element of the (widened) loop state.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) {
  // Report a parse failure as a failure of this test rather than aborting the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(
      auto m, ParseAndReturnVerifiedModule(kConstantHoistingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  HloComputation* while_body = m->GetComputationWithName("wide.body");
  ASSERT_NE(while_body, nullptr);

  // We expect the while body to be the equivalent of:
  //
  //  wide.body {
  //    wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0)
  //    get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0
  //    tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1)
  //    get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0
  //    get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1
  //    add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7)
  //    tuple.3 = (f32[2]{0}) tuple(add.1)
  //    get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0
  //    get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1
  //    ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8,
  //                                                get-tuple-element.9)
  //  }

  // Matchers mirroring the expected body above, built bottom-up.
  auto wide_param_1 = op::Parameter(0);
  auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0);
  auto tuple_1 = op::Tuple(get_tuple_element_1);
  auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0);
  auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1);
  auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7);
  auto tuple_3 = op::Tuple(add_1);
  auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0);
  auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1);
  auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9);

  EXPECT_THAT(while_body->root_instruction(), tuple_4);
}
533
// With a default-constructed pass, the constant in the loop body is expected
// to be left alone, so the pass reports no change.
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) {
  // Report a parse failure as a failure of this test rather than aborting the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(
      auto m, ParseAndReturnVerifiedModule(kConstantHoistingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
541
// The body writes -1 into element 3 while the condition requires element 3 to
// equal 42, so the loop cannot run a second iteration. The pass is expected to
// leave such a loop untouched even though the add reads only forwarded
// elements.
TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) {
  const char* const kSingleIterationModule = R"(
HloModule ModuleWithWhile

body {
p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0)
val.0 = f32[2] get-tuple-element(p_body), index=0
val.1 = f32[2] get-tuple-element(p_body), index=1
add = f32[2] add(val.0, val.1)
const = s32[] constant(-1)
ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const)
}

condition {
p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=3
const = s32[] constant(42)
ROOT result = pred[] compare(gte, const), direction=EQ
}

ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1)
ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body
})";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> parsed_module,
      ParseAndReturnVerifiedModule(kSingleIterationModule));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion{}.Run(parsed_module.get()));
  EXPECT_FALSE(simplified_loop);
}
575
// HLO text shared by the size-inflating tests. The while body ignores its
// parameter: it materializes an f32[1024, 1024] iota, adds it to itself, and
// reduces the result to a scalar with the `mul` computation. The condition is
// the constant `true`.
const char* const kInflatingTestCase = R"(
HloModule ModuleWithWhile

mul {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT mul = f32[] multiply(lhs, rhs)
}

body {
p_body = (f32[]) parameter(0)
iota = f32[1024, 1024] iota(), iota_dimension=0
add = f32[1024, 1024] add(iota, iota)
constant = f32[] constant(1.0)
reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul
ROOT root = (f32[]) tuple(reduce)
}

condition {
p_cond = (f32[]) parameter(0)
ROOT result = pred[] constant(true)
}

ENTRY entry {
param = f32[] parameter(0)
while_init = (f32[]) tuple(param)
ROOT while = (f32[]) while(while_init), condition=condition, body=body
}
)";
605
// When only hoist_constants is specified, the iota in the loop body is
// expected to be hoisted: the rewritten body ("wide.body") must no longer
// contain it.
TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) {
  // Report a parse failure as a failure of this test rather than aborting the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(kInflatingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get()));
  EXPECT_TRUE(simplified_loop);

  HloComputation* while_body = m->GetComputationWithName("wide.body");
  ASSERT_NE(while_body, nullptr);
  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota())));
}
618
// With hoist_size_inflating_ops disabled, the pass is expected to report no
// change for the inflating test module, even though constant hoisting is
// enabled.
TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) {
  // Report a parse failure as a failure of this test rather than aborting the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(kInflatingTestCase));

  TF_ASSERT_OK_AND_ASSIGN(
      bool simplified_loop,
      WhileLoopInvariantCodeMotion(/*hoist_constants=*/true,
                                   /*hoist_size_inflating_ops=*/false)
          .Run(m.get()));
  EXPECT_FALSE(simplified_loop);
}
629
630 } // namespace
631 } // namespace xla
632