1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22
23 namespace xla {
24 namespace {
25
26 namespace op = xla::testing::opcode_matchers;
27 using ::testing::_;
28
29 class WhileLoopConstantSinkingTest : public ::testing::Test {};
30
// Tuple element 1 is passed through the body unchanged and starts as a
// constant, so the `add` inside the body is expected to read a constant
// operand after the pass runs.
TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1

  add.0 = f32[2] add(p_body.0, p_body.1)
  ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1)
}

condition {
  p_cond = (f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  while_init = (f32[2],f32[2]) tuple(const_0, const_1)
  ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* body = hlo_module->GetComputationWithName("body");
  auto* root = body->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(_, op::Constant()), _));
}
68
// Elements 1 and 2 are both pass-through values with constant initial
// values. Their use in `add` should become constants, while the root tuple
// keeps forwarding the get-tuple-elements so the loop-invariant shape of the
// body is preserved.
TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1
  p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2

  add.0 = f32[2] add(p_body.1, p_body.2)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2)
}

condition {
  p_cond = (f32[2],f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  const_2 = f32[2] constant({3, 1})
  while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* body = hlo_module->GetComputationWithName("body");
  auto* root = body->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(op::Constant(), op::Constant()),
                              op::GetTupleElement(op::Parameter(0)),
                              op::GetTupleElement(op::Parameter(0))));
}
110
// Loop-state element 1 is itself a tuple with a tuple-shaped constant as its
// initial value. After sinking, the body's inner get-tuple-element should read
// from a constant, and the root tuple should still forward element 1 from the
// parameter.
//
// The body ROOT shape must equal the body parameter shape,
// (f32[2],(f32[2],f32[2])), for the while loop to be well-formed.
TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[2],(f32[2],f32[2])) parameter(0)
  p_b.0 = f32[2] get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=0
  p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1

  p_b.1.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_b.1), index=0

  ROOT root = (f32[2],(f32[2],f32[2])) tuple(p_b.1.0, p_b.1)
}

condition {
  p_cond = (f32[2],(f32[2],f32[2])) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1}))
  while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1)
  ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopConstantSinking{}.Run(module.get()));
  ASSERT_TRUE(changed);

  auto* while_body = module->GetComputationWithName("body");
  EXPECT_THAT(while_body->root_instruction(),
              op::Tuple(op::GetTupleElement(op::Constant(), 0),
                        op::GetTupleElement(op::Parameter(0))));
}
150
// Documents a known limitation on non-canonical IR: the body holds two
// distinct get-tuple-elements of index 2 (p_b.2 and p_b.2.dup). Only p_b.2
// feeds the root tuple, and the `add` consumes p_b.2.dup. That second operand
// is expected NOT to become a constant. The pass depends on an earlier HLO CSE
// run to merge identical get-tuple-elements before it can sink such a value.
TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[2],f32[2],f32[2]) parameter(0)

  p_b.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1
  p_b.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2
  p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2

  add.0 = f32[2] add(p_b.1, p_b.2.dup)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2)
}

condition {
  p_cond = (f32[2],f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  const_2 = f32[2] constant({3, 1})
  while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* body = hlo_module->GetComputationWithName("body");
  auto* root = body->root_instruction();
  EXPECT_THAT(root,
              op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())),
                        op::GetTupleElement(op::Parameter(0)),
                        op::GetTupleElement(op::Parameter(0))));
}
199
// The root tuple uses p_body.1 twice, at positions 1 and 2. It has no other
// users, so sinking must not leave behind a constant that nothing reads. The
// outfeed of p_body.0 is an additional user, which is what makes the pass
// report a change.
//
// The body parameter, the body root, the condition parameter, the while init,
// and the while result all share the same 3-tuple shape, so the module is
// well-formed HLO.
TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1

  token0 = token[] after-all()
  outfeed = token[] outfeed(p_body.0, token0)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
}

condition {
  p_cond = (f32[2],f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  const_2 = f32[2] constant({3, 1})
  while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition,
    body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopConstantSinking{}.Run(module.get()));
  ASSERT_TRUE(changed);

  auto* while_body = module->GetComputationWithName("body");
  EXPECT_THAT(while_body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::GetTupleElement()));
  // Every constant that ends up in the body must have at least one user.
  for (const HloInstruction* inst : while_body->instructions()) {
    if (inst->opcode() == HloOpcode::kConstant) {
      EXPECT_GT(inst->user_count(), 0);
    }
  }
}
245
// The loop bound (element 1) is forwarded unchanged by the body and starts as
// the constant 10. After the pass, the comparison in the condition
// computation is expected to take a constant as its right-hand side.
TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1
  ROOT root = (f32[],f32[]) tuple(add, p_body.1)
}

condition {
  p_cond = (f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
  p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
  ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(10)
  while_init = (f32[],f32[]) tuple(const_0, const_1)
  ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* cond = hlo_module->GetComputationWithName("condition");
  auto* cond_root = cond->root_instruction();
  EXPECT_THAT(cond_root, op::Lt(_, op::Constant()));
}
284
// Element 1 of the loop state is a tuple-shaped constant that the body
// forwards unchanged. In the condition, the nested get-tuple-element feeding
// the comparison is expected to read from a sunk constant.
TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[],(f32[],f32[])) parameter(0)
  p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0
  p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1
  p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0
  add = f32[] add(p_b.0, p_b.1.0)
  ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1)
}

condition {
  p_c = (f32[],(f32[],f32[])) parameter(0)
  p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
  p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
  p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
  ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = (f32[], f32[]) constant((1, 10))
  while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1)
  ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* cond = hlo_module->GetComputationWithName("condition");
  auto* cond_root = cond->root_instruction();
  EXPECT_THAT(cond_root, op::Lt(_, op::GetTupleElement(op::Constant())));
}
325
// Elements 1 and 2 are both constant pass-through values. The condition reads
// element 2 through p_cond.2 but never uses it, so no constant should be left
// dangling for it. The comparison should still pick up a constant for
// element 1.
TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1
  p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2
  ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2)
}

condition {
  p_cond = (f32[],f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
  p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
  p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(10)
  const_2 = f32[] constant(12)
  while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* cond = hlo_module->GetComputationWithName("condition");
  EXPECT_THAT(cond->root_instruction(), op::Lt(_, op::Constant()));

  // No constant in the condition may be left without a user.
  for (const HloInstruction* instr : cond->instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    EXPECT_GT(instr->user_count(), 0);
  }
}
372
// The condition reads constant element 2 through two separate
// get-tuple-elements (p_cond.2 and p_cond.2.c). Both comparisons that use
// them are expected to end up comparing against a constant.
TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add.0 = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1
  add.1 = f32[] add(p_body.1, const)
  p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2
  ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2)
}

condition {
  p_cond = (f32[],f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
  p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT
  p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
  p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT
  ROOT result = pred[] and(lt.0, lt.1)
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(0)
  const_2 = f32[] constant(12)
  while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));
  WhileLoopConstantSinking sinking_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool pass_changed,
                          sinking_pass.Run(hlo_module.get()));
  ASSERT_TRUE(pass_changed);

  auto* cond = hlo_module->GetComputationWithName("condition");
  auto* cond_root = cond->root_instruction();
  EXPECT_THAT(cond_root,
              op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant())));
}
418 } // namespace
419 } // namespace xla
420