1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 
23 namespace xla {
24 namespace {
25 
26 namespace op = xla::testing::opcode_matchers;
27 using ::testing::_;
28 
29 class WhileLoopConstantSinkingTest : public ::testing::Test {};
30 
// Tuple element 1 is passed through the body unchanged and is initialized from
// a constant, so its non-root use in `add.0` should be rewritten to use a
// constant cloned into the body.
TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1

  add.0 = f32[2] add(p_body.0, p_body.1)
  ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1)
}

condition {
  p_cond = (f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  while_init = (f32[2],f32[2]) tuple(const_0, const_1)
  ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  HloComputation* const body = module->GetComputationWithName("body");
  const HloInstruction* const root = body->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Add(_, op::Constant()), _));
}
68 
// Two invariant constant elements (indices 1 and 2) feed `add.0`. Their uses in
// the add are replaced by constants, while the root keeps forwarding the
// parameter's elements so the loop-invariance of those indices is preserved.
TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1
  p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2

  add.0 = f32[2] add(p_body.1, p_body.2)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2)
}

condition {
  p_cond = (f32[2],f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  const_2 = f32[2] constant({3, 1})
  while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  // The root must still pass the invariant elements straight through.
  const auto forwarded_element = op::GetTupleElement(op::Parameter(0));
  HloComputation* const body = module->GetComputationWithName("body");
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::Add(op::Constant(), op::Constant()),
                        forwarded_element, forwarded_element));
}
110 
// Index 1 of the loop state is a tuple-shaped constant that the body passes
// through unchanged. Its non-root use (`p_b.1.1`) should be rewritten to read
// from a constant cloned into the body, while the root keeps forwarding the
// parameter element.
//
// The body root's declared shape matches the body parameter and while shape,
// (f32[2],(f32[2],f32[2])). It was previously mis-declared as a flat 3-tuple,
// which does not match the two operands of the tuple instruction.
TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[2],(f32[2],f32[2])) parameter(0)
  p_b.0 = f32[2] get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=0
  p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1

  p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0

  ROOT root = (f32[2],(f32[2],f32[2])) tuple(p_b.1.1, p_b.1)
}

condition {
  p_cond = (f32[2],(f32[2],f32[2])) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1}))
  while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1)
  ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          WhileLoopConstantSinking{}.Run(module.get()));
  ASSERT_TRUE(changed);

  auto* while_body = module->GetComputationWithName("body");
  EXPECT_THAT(while_body->root_instruction(),
              op::Tuple(op::GetTupleElement(op::Constant(), 0),
                        op::GetTupleElement(op::Parameter(0))));
}
150 
// Documents a known limitation: the pass does not handle non-canonical IR.
//
// `p_b.2.dup` reads the same invariant, constant-initialized element as
// `p_b.2`, but WhileLoopConstantSinking does not recognize duplicate GTEs on its
// own. It expects an earlier HLO CSE run to have merged identical
// get-tuple-element instructions. Only the use of `p_b.1` in the add is
// therefore expected to become a constant.
TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[2],f32[2],f32[2]) parameter(0)

  p_b.1     = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1
  p_b.2     = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2
  p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2

  add.0 = f32[2] add(p_b.1, p_b.2.dup)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2)
}

condition {
  p_cond = (f32[2],f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  const_2 = f32[2] constant({3, 1})
  while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  const auto forwarded_element = op::GetTupleElement(op::Parameter(0));
  const auto expected_add =
      op::Add(op::Constant(), ::testing::Not(op::Constant()));
  HloComputation* const body = module->GetComputationWithName("body");
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(expected_add, forwarded_element, forwarded_element));
}
199 
// Checks that the pass never leaves a sunk constant without users.
//
// `p_body.1` appears twice in the root tuple but has no other user. Sinking it
// would produce a constant nobody needs, so every constant left in the body
// must have at least one user and the root must still be built purely from
// get-tuple-elements.
//
// NOTE(review): the body returns a 3-tuple while its parameter, the condition
// parameter and `while_init` are 2-tuples. This looks deliberate, since it puts
// `p_body.1` into two root slots. Such a module would presumably be rejected by
// the HLO verifier; confirm this is intended.
TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[2],f32[2]) parameter(0)
  p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
  p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1

  token0 = token[] after-all()
  outfeed = token[] outfeed(p_body.0, token0)
  ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
}

condition {
  p_cond = (f32[2],f32[2]) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const_0 = f32[2] constant({1, 2})
  const_1 = f32[2] constant({2, 1})
  while_init = (f32[2],f32[2]) tuple(const_0, const_1)
  ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition,
                                      body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  HloComputation* const body = module->GetComputationWithName("body");
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::GetTupleElement()));
  for (const HloInstruction* inst : body->instructions()) {
    if (inst->opcode() != HloOpcode::kConstant) {
      continue;
    }
    // A constant with no users would be dead code introduced by the pass.
    EXPECT_GT(inst->user_count(), 0);
  }
}
245 
// Element 1 (the loop bound) is invariant in the body and initialized from a
// constant. Its read in the loop condition should be replaced by a constant
// sunk into the condition computation.
TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1
  ROOT root = (f32[],f32[]) tuple(add, p_body.1)
}

condition {
  p_cond = (f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
  p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
  ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(10)
  while_init = (f32[],f32[]) tuple(const_0, const_1)
  ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  HloComputation* const condition = module->GetComputationWithName("condition");
  const HloInstruction* const predicate = condition->root_instruction();
  EXPECT_THAT(predicate, op::Lt(_, op::Constant()));
}
284 
// The condition reads an element of a tuple-shaped invariant constant. After
// sinking, that read should be a get-tuple-element of a constant cloned into
// the condition computation.
TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_b = (f32[],(f32[],f32[])) parameter(0)
  p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0
  p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1
  p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0
  add = f32[] add(p_b.0, p_b.1.0)
  ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1)
}

condition {
  p_c = (f32[],(f32[],f32[])) parameter(0)
  p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
  p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
  p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
  ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = (f32[], f32[]) constant((1, 10))
  while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1)
  ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  const auto element_of_sunk_tuple = op::GetTupleElement(op::Constant());
  HloComputation* const condition = module->GetComputationWithName("condition");
  EXPECT_THAT(condition->root_instruction(),
              op::Lt(_, element_of_sunk_tuple));
}
325 
// The condition contains `p_cond.2`, a get-tuple-element with no users, for an
// invariant constant element. Sinking must not add a constant for it that ends
// up unused. The used bound (`p_cond.1`) is still replaced by a constant.
TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1
  p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2
  ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2)
}

condition {
  p_cond = (f32[],f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
  p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
  p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(10)
  const_2 = f32[] constant(12)
  while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  HloComputation* const condition = module->GetComputationWithName("condition");
  EXPECT_THAT(condition->root_instruction(), op::Lt(_, op::Constant()));
  for (const HloInstruction* inst : condition->instructions()) {
    if (inst->opcode() != HloOpcode::kConstant) {
      continue;
    }
    // Every constant in the condition must actually be consumed.
    EXPECT_GT(inst->user_count(), 0);
  }
}
372 
// The condition reads the invariant constant element 2 through two separate
// get-tuple-element instructions (`p_cond.2` and `p_cond.2.c`). Both reads
// should be replaced, so each comparison ends up against a constant.
TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) {
  constexpr char kHloText[] = R"(
HloModule ModuleWithWhile

body {
  p_body = (f32[],f32[],f32[]) parameter(0)
  p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0
  const = f32[] constant(1)
  add.0 = f32[] add(p_body.0, const)
  p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1
  add.1 = f32[] add(p_body.1, const)
  p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2
  ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2)
}

condition {
  p_cond = (f32[],f32[],f32[]) parameter(0)
  p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
  p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT
  p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
  p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
  lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT
  ROOT result = pred[] and(lt.0, lt.1)
}

ENTRY entry {
  const_0 = f32[] constant(0)
  const_1 = f32[] constant(0)
  const_2 = f32[] constant(12)
  while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2)
  ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHloText));

  WhileLoopConstantSinking pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(module.get()));
  ASSERT_TRUE(changed);

  // Both comparisons should use the sunk constant as their right-hand side.
  const auto less_than_constant = op::Lt(_, op::Constant());
  HloComputation* const condition = module->GetComputationWithName("condition");
  EXPECT_THAT(condition->root_instruction(),
              op::And(less_than_constant, less_than_constant));
}
418 }  // namespace
419 }  // namespace xla
420